diff --git a/adv_control/control.py b/adv_control/control.py
index 112bbe1..9cc4497 100644
--- a/adv_control/control.py
+++ b/adv_control/control.py
@@ -12,7 +12,7 @@ from comfy.model_patcher import ModelPatcher
 
 from .control_sparsectrl import SparseModelPatcher, SparseControlNet, SparseCtrlMotionWrapper, SparseSettings, SparseConst
-from .control_lllite import LLLiteModule, LLLitePatch
+from .control_lllite import LLLiteModule, LLLitePatch, load_controllllite
 from .control_svd import svd_unet_config_from_diffusers_unet, SVDControlNet, svd_unet_to_diffusers
 from .utils import (AdvancedControlBase, TimestepKeyframeGroup, LatentKeyframeGroup, AbstractPreprocWrapper, ControlWeightType, ControlWeights, WeightTypeException,
                     manual_cast_clean_groupnorm, disable_weight_init_clean_groupnorm, prepare_mask_batch, get_properly_arranged_t2i_weights, load_torch_file_with_dict_factory,
@@ -20,6 +20,9 @@
 from .logger import logger
 
 
+ORIG_PREVIOUS_CONTROLNET = "_orig_previous_controlnet"
+
+
 class ControlNetAdvanced(ControlNet, AdvancedControlBase):
     def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
         super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, compression_ratio=compression_ratio, latent_format=latent_format, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
@@ -98,8 +101,10 @@ def copy(self):
 
     @staticmethod
     def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlNetAdvanced':
-        return ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
+        to_return = ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
                                   global_average_pooling=v.global_average_pooling, compression_ratio=v.compression_ratio, latent_format=v.latent_format, device=v.device, load_device=v.load_device, manual_cast_dtype=v.manual_cast_dtype)
+        v.copy_to(to_return)
+        return to_return
 
 
 class T2IAdapterAdvanced(T2IAdapter, AdvancedControlBase):
@@ -166,8 +171,10 @@ def cleanup(self):
 
     @staticmethod
     def from_vanilla(v: T2IAdapter, timestep_keyframe: TimestepKeyframeGroup=None) -> 'T2IAdapterAdvanced':
-        return T2IAdapterAdvanced(t2i_model=v.t2i_model, timestep_keyframes=timestep_keyframe, channels_in=v.channels_in,
+        to_return = T2IAdapterAdvanced(t2i_model=v.t2i_model, timestep_keyframes=timestep_keyframe, channels_in=v.channels_in,
                                   compression_ratio=v.compression_ratio, upscale_algorithm=v.upscale_algorithm, device=v.device)
+        v.copy_to(to_return)
+        return to_return
 
 
 class ControlLoraAdvanced(ControlLora, AdvancedControlBase):
@@ -194,8 +201,10 @@ def cleanup(self):
 
     @staticmethod
     def from_vanilla(v: ControlLora, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlLoraAdvanced':
-        return ControlLoraAdvanced(control_weights=v.control_weights, timestep_keyframes=timestep_keyframe,
+        to_return = ControlLoraAdvanced(control_weights=v.control_weights, timestep_keyframes=timestep_keyframe,
                                    global_average_pooling=v.global_average_pooling, device=v.device)
+        v.copy_to(to_return)
+        return to_return
 
 
 class SVDControlNetAdvanced(ControlNetAdvanced):
@@ -408,115 +417,6 @@ def copy(self):
         return c
 
 
-class ControlLLLiteAdvanced(ControlBase, AdvancedControlBase):
-    # This ControlNet is more of an attention patch than a traditional controlnet
-    def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch, timestep_keyframes: TimestepKeyframeGroup, device=None):
-        super().__init__(device)
-        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite(), require_model=True)
-        self.patch_attn1 = patch_attn1.set_control(self)
-        self.patch_attn2 = patch_attn2.set_control(self)
-        self.latent_dims_div2 = None
-        self.latent_dims_div4 = None
-
-    def patch_model(self, model: ModelPatcher):
-        model.set_model_attn1_patch(self.patch_attn1)
-        model.set_model_attn2_patch(self.patch_attn2)
-
-    def set_cond_hint_inject(self, *args, **kwargs):
-        to_return = super().set_cond_hint_inject(*args, **kwargs)
-        # cond hint for LLLite needs to be scaled between (-1, 1) instead of (0, 1)
-        self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0
-        return to_return
-
-    def pre_run_advanced(self, *args, **kwargs):
-        AdvancedControlBase.pre_run_advanced(self, *args, **kwargs)
-        #logger.error(f"in cn: {id(self.patch_attn1)},{id(self.patch_attn2)}")
-        self.patch_attn1.set_control(self)
-        self.patch_attn2.set_control(self)
-        #logger.warn(f"in pre_run_advanced: {id(self)}")
-
-    def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
-        # normal ControlNet stuff
-        control_prev = None
-        if self.previous_controlnet is not None:
-            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
-
-        if self.timestep_range is not None:
-            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
-                return control_prev
-
-        dtype = x_noisy.dtype
-        # prepare cond_hint
-        if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
-            if self.cond_hint is not None:
-                del self.cond_hint
-            self.cond_hint = None
-            # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
-            if self.sub_idxs is not None:
-                actual_cond_hint_orig = self.cond_hint_original
-                if self.cond_hint_original.size(0) < self.full_latent_length:
-                    actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
-                self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
-            else:
-                self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
-        if x_noisy.shape[0] != self.cond_hint.shape[0]:
-            self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
-        # some special logic here compared to other controlnets:
-        # * The cond_emb in attn patches will divide latent dims by 2 or 4, integer
-        # * Due to this loss, the cond_emb will become smaller than x input if latent dims are not divisble by 2 or 4
-        divisible_by_2_h = x_noisy.shape[2]%2==0
-        divisible_by_2_w = x_noisy.shape[3]%2==0
-        if not (divisible_by_2_h and divisible_by_2_w):
-            #logger.warn(f"{x_noisy.shape} not divisible by 2!")
-            new_h = (x_noisy.shape[2]//2)*2
-            new_w = (x_noisy.shape[3]//2)*2
-            if not divisible_by_2_h:
-                new_h += 2
-            if not divisible_by_2_w:
-                new_w += 2
-            self.latent_dims_div2 = (new_h, new_w)
-        divisible_by_4_h = x_noisy.shape[2]%4==0
-        divisible_by_4_w = x_noisy.shape[3]%4==0
-        if not (divisible_by_4_h and divisible_by_4_w):
-            #logger.warn(f"{x_noisy.shape} not divisible by 4!")
-            new_h = (x_noisy.shape[2]//4)*4
-            new_w = (x_noisy.shape[3]//4)*4
-            if not divisible_by_4_h:
-                new_h += 4
-            if not divisible_by_4_w:
-                new_w += 4
-            self.latent_dims_div4 = (new_h, new_w)
-        # prepare mask
-        self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
-        # done preparing; model patches will take care of everything now.
-        # return normal controlnet stuff
-        return control_prev
-
-    def cleanup_advanced(self):
-        super().cleanup_advanced()
-        self.patch_attn1.cleanup()
-        self.patch_attn2.cleanup()
-        self.latent_dims_div2 = None
-        self.latent_dims_div4 = None
-
-    def copy(self):
-        c = ControlLLLiteAdvanced(self.patch_attn1, self.patch_attn2, self.timestep_keyframes)
-        self.copy_to(c)
-        self.copy_to_advanced(c)
-        return c
-
-    # deepcopy needs to properly keep track of objects to work between model.clone calls!
-    # def __deepcopy__(self, *args, **kwargs):
-    #     self.cleanup_advanced()
-    #     return self
-
-    # def get_models(self):
-    #     # get_models is called once at the start of every KSampler run - use to reset already_patched status
-    #     out = super().get_models()
-    #     logger.error(f"in get_models! {id(self)}")
-    #     return out
-
-
 def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
     # from pathlib import Path
@@ -591,6 +491,112 @@ def convert_to_advanced(control, timestep_keyframe: TimestepKeyframeGroup=None):
     return control
 
 
+def convert_all_to_advanced(conds: list[list[dict[str]]]) -> tuple[bool, list]:
+    cache = {}
+    modified = False
+    new_conds = []
+    for cond in conds:
+        converted_cond = None
+        if cond is not None:
+            need_to_convert = False
+            # first, check if there is even a need to convert
+            for sub_cond in cond:
+                actual_cond = sub_cond[1]
+                if "control" in actual_cond:
+                    if not are_all_advanced_controlnet(actual_cond["control"]):
+                        need_to_convert = True
+                        break
+            if not need_to_convert:
+                converted_cond = cond
+            else:
+                converted_cond = []
+                for sub_cond in cond:
+                    new_sub_cond: list = []
+                    for actual_cond in sub_cond:
+                        if not type(actual_cond) == dict:
+                            new_sub_cond.append(actual_cond)
+                            continue
+                        if "control" not in actual_cond:
+                            new_sub_cond.append(actual_cond)
+                        elif are_all_advanced_controlnet(actual_cond["control"]):
+                            new_sub_cond.append(actual_cond)
+                        else:
+                            actual_cond = actual_cond.copy()
+                            actual_cond["control"] = _convert_all_control_to_advanced(actual_cond["control"], cache)
+                            new_sub_cond.append(actual_cond)
+                            modified = True
+                    converted_cond.append(new_sub_cond)
+        new_conds.append(converted_cond)
+    return modified, new_conds
+
+
+def _convert_all_control_to_advanced(input_object: ControlBase, cache: dict):
+    output_object = input_object
+    # iteratively convert to advanced, if needed
+    next_cn = None
+    curr_cn = input_object
+    iter = 0
+    while curr_cn is not None:
+        if not is_advanced_controlnet(curr_cn):
+            # if already in cache, then conversion was done before, so just link it and exit
+            if curr_cn in cache:
+                new_cn = cache[curr_cn]
+                if next_cn is not None:
+                    setattr(next_cn, ORIG_PREVIOUS_CONTROLNET, next_cn.previous_controlnet)
+                    next_cn.previous_controlnet = new_cn
+                if iter == 0: # if was top-level controlnet, that's the new output
+                    output_object = new_cn
+                break
+            try:
+                # convert to advanced, and assign previous_controlnet (convert doesn't transfer it)
+                new_cn = convert_to_advanced(curr_cn)
+            except Exception as e:
+                raise Exception("Failed to automatically convert a ControlNet to Advanced to support sliding window context.", e)
+            new_cn.previous_controlnet = curr_cn.previous_controlnet
+            if iter == 0: # if was top-level controlnet, that's the new output
+                output_object = new_cn
+            # if next_cn is present, then it needs to be pointed to new_cn
+            if next_cn is not None:
+                setattr(next_cn, ORIG_PREVIOUS_CONTROLNET, next_cn.previous_controlnet)
+                next_cn.previous_controlnet = new_cn
+            # add to cache
+            cache[curr_cn] = new_cn
+            curr_cn = new_cn
+        next_cn = curr_cn
+        curr_cn = curr_cn.previous_controlnet
+        iter += 1
+    return output_object
+
+
+def restore_all_controlnet_conns(conds: list[list[dict[str]]]):
+    # if a cn has an _orig_previous_controlnet property, restore it and delete
+    for main_cond in conds:
+        if main_cond is not None:
+            for cond in main_cond:
+                if "control" in cond[1]:
+                    _restore_all_controlnet_conns(cond[1]["control"])
+
+
+def _restore_all_controlnet_conns(input_object: ControlBase):
+    # restore original previous_controlnet if needed
+    curr_cn = input_object
+    while curr_cn is not None:
+        if hasattr(curr_cn, ORIG_PREVIOUS_CONTROLNET):
+            curr_cn.previous_controlnet = getattr(curr_cn, ORIG_PREVIOUS_CONTROLNET)
+            delattr(curr_cn, ORIG_PREVIOUS_CONTROLNET)
+        curr_cn = curr_cn.previous_controlnet
+
+
+def are_all_advanced_controlnet(input_object: ControlBase):
+    # iteratively check if linked controlnet objects are all advanced
+    curr_cn = input_object
+    while curr_cn is not None:
+        if not is_advanced_controlnet(curr_cn):
+            return False
+        curr_cn = curr_cn.previous_controlnet
+    return True
+
+
 def is_advanced_controlnet(input_object):
     return hasattr(input_object, "sub_idxs")
 
@@ -749,55 +755,6 @@ class WeightsLoader(torch.nn.Module):
     return control
 
 
-def load_controllllite(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None):
-    if controlnet_data is None:
-        controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
-    # adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
-    # first, split weights for each module
-    module_weights = {}
-    for key, value in controlnet_data.items():
-        fragments = key.split(".")
-        module_name = fragments[0]
-        weight_name = ".".join(fragments[1:])
-
-        if module_name not in module_weights:
-            module_weights[module_name] = {}
-        module_weights[module_name][weight_name] = value
-
-    # next, load each module
-    modules = {}
-    for module_name, weights in module_weights.items():
-        # kohya planned to do something about how these should be chosen, so I'm not touching this
-        # since I am not familiar with the logic for this
-        if "conditioning1.4.weight" in weights:
-            depth = 3
-        elif weights["conditioning1.2.weight"].shape[-1] == 4:
-            depth = 2
-        else:
-            depth = 1
-
-        module = LLLiteModule(
-            name=module_name,
-            is_conv2d=weights["down.0.weight"].ndim == 4,
-            in_dim=weights["down.0.weight"].shape[1],
-            depth=depth,
-            cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
-            mlp_dim=weights["down.0.weight"].shape[0],
-        )
-        # load weights into module
-        module.load_state_dict(weights)
-        modules[module_name] = module
-        if len(modules) == 1:
-            module.is_first = True
-
-    #logger.info(f"loaded {ckpt_path} successfully, {len(modules)} modules")
-
-    patch_attn1 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN1)
-    patch_attn2 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN2)
-    control = ControlLLLiteAdvanced(patch_attn1=patch_attn1, patch_attn2=patch_attn2, timestep_keyframes=timestep_keyframe)
-    return control
-
-
 def load_svdcontrolnet(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
     if controlnet_data is None:
         controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
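
The hunk above re-links a `previous_controlnet` chain around converted copies, stashing each original link under `_orig_previous_controlnet` so `restore_all_controlnet_conns` can undo the swap after sampling. A minimal sketch of that round trip, using hypothetical stub classes (`VanillaCN`/`AdvancedCN` stand in for comfy ControlNet types; only the linking attributes are modeled):

```python
ORIG_PREVIOUS_CONTROLNET = "_orig_previous_controlnet"

class VanillaCN:
    def __init__(self, name, previous_controlnet=None):
        self.name = name
        self.previous_controlnet = previous_controlnet

class AdvancedCN(VanillaCN):
    sub_idxs = None  # marker attribute, mirrors what is_advanced_controlnet() checks

def convert(cn):
    # stand-in for convert_to_advanced(); real code assigns previous_controlnet afterward
    return AdvancedCN(cn.name + "_adv", cn.previous_controlnet)

# chain: head (already advanced) -> middle (vanilla) -> tail (vanilla)
tail = VanillaCN("tail")
middle = VanillaCN("middle", tail)
head = AdvancedCN("head", middle)

# convert pass: swap converted copies into the chain, remembering original links
curr, next_cn = head, None
while curr is not None:
    if not isinstance(curr, AdvancedCN):
        new_cn = convert(curr)
        if next_cn is not None:
            setattr(next_cn, ORIG_PREVIOUS_CONTROLNET, next_cn.previous_controlnet)
            next_cn.previous_controlnet = new_cn
        curr = new_cn
    next_cn = curr
    curr = curr.previous_controlnet

# restore pass: exactly what _restore_all_controlnet_conns does
curr = head
while curr is not None:
    if hasattr(curr, ORIG_PREVIOUS_CONTROLNET):
        curr.previous_controlnet = getattr(curr, ORIG_PREVIOUS_CONTROLNET)
        delattr(curr, ORIG_PREVIOUS_CONTROLNET)
    curr = curr.previous_controlnet

assert head.previous_controlnet is middle  # original chain is back
```

The real implementation additionally caches conversions, so the same vanilla ControlNet shared between positive and negative conds is converted only once.
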
diff --git a/adv_control/control_lllite.py b/adv_control/control_lllite.py
index 96cb471..4d9dcc1 100644
--- a/adv_control/control_lllite.py
+++ b/adv_control/control_lllite.py
@@ -8,10 +8,33 @@ import os
 import comfy.utils
+import comfy.ops
+import comfy.model_management
+from comfy.model_patcher import ModelPatcher
 from comfy.controlnet import ControlBase
 
 from .logger import logger
-from .utils import AdvancedControlBase, deepcopy_with_sharing, prepare_mask_batch
+from .utils import (AdvancedControlBase, TimestepKeyframeGroup, ControlWeights, broadcast_image_to_extend, extend_to_batch_size,
+                    deepcopy_with_sharing, prepare_mask_batch)
+
+
+# based on set_model_patch code in comfy/model_patcher.py
+def set_model_patch(model_options, patch, name):
+    to = model_options["transformer_options"]
+    # check if patch was already added
+    if "patches" in to:
+        current_patches = to["patches"].get(name, [])
+        if patch in current_patches:
+            return
+    if "patches" not in to:
+        to["patches"] = {}
+    to["patches"][name] = to["patches"].get(name, []) + [patch]
+
+def set_model_attn1_patch(model_options, patch):
+    set_model_patch(model_options, patch, "attn1_patch")
+
+def set_model_attn2_patch(model_options, patch):
+    set_model_patch(model_options, patch, "attn2_patch")
 
 
 def extra_options_to_module_prefix(extra_options):
@@ -100,18 +123,18 @@ def cleanup(self):
         #logger.error(f"cleanup LLLitePatch: {id(self)}")
 
     # make sure deepcopy does not copy control, and deepcopied LLLitePatch should be assigned to control
-    def __deepcopy__(self, memo):
-        self.cleanup()
-        to_return: LLLitePatch = deepcopy_with_sharing(self, shared_attribute_names = ['control'], memo=memo)
-        #logger.warn(f"patch {id(self)} turned into {id(to_return)}")
-        try:
-            if self.patch_type == self.ATTN1:
-                to_return.control.patch_attn1 = to_return
-            elif self.patch_type == self.ATTN2:
-                to_return.control.patch_attn2 = to_return
-        except Exception:
-            pass
-        return to_return
+    # def __deepcopy__(self, memo):
+    #     self.cleanup()
+    #     to_return: LLLitePatch = deepcopy_with_sharing(self, shared_attribute_names = ['control'], memo=memo)
+    #     #logger.warn(f"patch {id(self)} turned into {id(to_return)}")
+    #     try:
+    #         if self.patch_type == self.ATTN1:
+    #             to_return.control.patch_attn1 = to_return
+    #         elif self.patch_type == self.ATTN2:
+    #             to_return.control.patch_attn2 = to_return
+    #     except Exception:
+    #         pass
+    #     return to_return
 
 
 # TODO: use comfy.ops to support fp8 properly
@@ -252,3 +275,188 @@ def forward(self, x: Tensor, control: Union[AdvancedControlBase, ControlBase]):
         if cond_type == 1:
             cx[actual_length*idx:actual_length*(idx+1)] *= control.weights.uncond_multiplier
     return cx * mask * control.strength * control._current_timestep_keyframe.strength
+
+
+class ControlLLLiteModules(torch.nn.Module):
+    def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch):
+        super().__init__()
+        self.patch_attn1_modules = torch.nn.Sequential(*list(patch_attn1.modules.values()))
+        self.patch_attn2_modules = torch.nn.Sequential(*list(patch_attn2.modules.values()))
+
+
+class ControlLLLiteAdvanced(ControlBase, AdvancedControlBase):
+    # This ControlNet is more of an attention patch than a traditional controlnet
+    def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch, timestep_keyframes: TimestepKeyframeGroup, device, ops: comfy.ops.disable_weight_init):
+        super().__init__(device)
+        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite())
+        self.device = device
+        self.ops = ops
+        self.patch_attn1 = patch_attn1.clone_with_control(self)
+        self.patch_attn2 = patch_attn2.clone_with_control(self)
+        self.control_model = ControlLLLiteModules(self.patch_attn1, self.patch_attn2)
+        self.control_model_wrapped = ModelPatcher(self.control_model, load_device=device, offload_device=comfy.model_management.unet_offload_device())
+        self.latent_dims_div2 = None
+        self.latent_dims_div4 = None
+
+    def live_model_patches(self, model_options):
+        set_model_attn1_patch(model_options, self.patch_attn1.set_control(self))
+        set_model_attn2_patch(model_options, self.patch_attn2.set_control(self))
+
+    # def patch_model(self, model: ModelPatcher):
+    #     model.set_model_attn1_patch(self.patch_attn1)
+    #     model.set_model_attn2_patch(self.patch_attn2)
+
+    def set_cond_hint_inject(self, *args, **kwargs):
+        to_return = super().set_cond_hint_inject(*args, **kwargs)
+        # cond hint for LLLite needs to be scaled between (-1, 1) instead of (0, 1)
+        self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0
+        return to_return
+
+    def pre_run_advanced(self, *args, **kwargs):
+        AdvancedControlBase.pre_run_advanced(self, *args, **kwargs)
+        #logger.error(f"in cn: {id(self.patch_attn1)},{id(self.patch_attn2)}")
+        self.patch_attn1.set_control(self)
+        self.patch_attn2.set_control(self)
+        #logger.warn(f"in pre_run_advanced: {id(self)}")
+
+    def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
+        # normal ControlNet stuff
+        control_prev = None
+        if self.previous_controlnet is not None:
+            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
+
+        if self.timestep_range is not None:
+            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+                return control_prev
+
+        dtype = x_noisy.dtype
+        # prepare cond_hint
+        if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
+            if self.cond_hint is not None:
+                del self.cond_hint
+            self.cond_hint = None
+            # when sub_idxs are present, extend cond_hint_original to the full latent count if needed, then subdivide it before scaling
+            if self.sub_idxs is not None:
+                actual_cond_hint_orig = self.cond_hint_original
+                if self.cond_hint_original.size(0) < self.full_latent_length:
+                    actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
+                self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
+            else:
+                self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
+        if x_noisy.shape[0] != self.cond_hint.shape[0]:
+            self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
+        # some special logic here compared to other controlnets:
+        # * The cond_emb in attn patches will divide latent dims by 2 or 4, integer
+        # * Due to this loss, the cond_emb will become smaller than x input if latent dims are not divisible by 2 or 4
+        divisible_by_2_h = x_noisy.shape[2]%2==0
+        divisible_by_2_w = x_noisy.shape[3]%2==0
+        if not (divisible_by_2_h and divisible_by_2_w):
+            #logger.warn(f"{x_noisy.shape} not divisible by 2!")
+            new_h = (x_noisy.shape[2]//2)*2
+            new_w = (x_noisy.shape[3]//2)*2
+            if not divisible_by_2_h:
+                new_h += 2
+            if not divisible_by_2_w:
+                new_w += 2
+            self.latent_dims_div2 = (new_h, new_w)
+        divisible_by_4_h = x_noisy.shape[2]%4==0
+        divisible_by_4_w = x_noisy.shape[3]%4==0
+        if not (divisible_by_4_h and divisible_by_4_w):
+            #logger.warn(f"{x_noisy.shape} not divisible by 4!")
+            new_h = (x_noisy.shape[2]//4)*4
+            new_w = (x_noisy.shape[3]//4)*4
+            if not divisible_by_4_h:
+                new_h += 4
+            if not divisible_by_4_w:
+                new_w += 4
+            self.latent_dims_div4 = (new_h, new_w)
+        # prepare mask
+        self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
+        # done preparing; model patches will take care of everything now.
+        # return normal controlnet stuff
+        return control_prev
+
+    def get_models(self):
+        to_return: list = super().get_models()
+        to_return.append(self.control_model_wrapped)
+        return to_return
+
+    def cleanup_advanced(self):
+        super().cleanup_advanced()
+        self.patch_attn1.cleanup()
+        self.patch_attn2.cleanup()
+        self.latent_dims_div2 = None
+        self.latent_dims_div4 = None
+
+    def copy(self):
+        c = ControlLLLiteAdvanced(self.patch_attn1, self.patch_attn2, self.timestep_keyframes, self.device, self.ops)
+        self.copy_to(c)
+        self.copy_to_advanced(c)
+        return c
+
+    # deepcopy needs to properly keep track of objects to work between model.clone calls!
+    # def __deepcopy__(self, *args, **kwargs):
+    #     self.cleanup_advanced()
+    #     return self
+
+    # def get_models(self):
+    #     # get_models is called once at the start of every KSampler run - use to reset already_patched status
+    #     out = super().get_models()
+    #     logger.error(f"in get_models! {id(self)}")
+    #     return out
+
+
+def load_controllllite(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None):
+    if controlnet_data is None:
+        controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
+    # adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
+    # first, split weights for each module
+    module_weights = {}
+    for key, value in controlnet_data.items():
+        fragments = key.split(".")
+        module_name = fragments[0]
+        weight_name = ".".join(fragments[1:])
+
+        if module_name not in module_weights:
+            module_weights[module_name] = {}
+        module_weights[module_name][weight_name] = value
+
+    unet_dtype = comfy.model_management.unet_dtype()
+    load_device = comfy.model_management.get_torch_device()
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+    ops = comfy.ops.disable_weight_init
+    if manual_cast_dtype is not None:
+        ops = comfy.ops.manual_cast
+
+    # next, load each module
+    modules = {}
+    for module_name, weights in module_weights.items():
+        # kohya planned to do something about how these should be chosen, so I'm not touching this
+        # since I am not familiar with the logic for this
+        if "conditioning1.4.weight" in weights:
+            depth = 3
+        elif weights["conditioning1.2.weight"].shape[-1] == 4:
+            depth = 2
+        else:
+            depth = 1
+
+        module = LLLiteModule(
+            name=module_name,
+            is_conv2d=weights["down.0.weight"].ndim == 4,
+            in_dim=weights["down.0.weight"].shape[1],
+            depth=depth,
+            cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
+            mlp_dim=weights["down.0.weight"].shape[0],
+        )
+        # load weights into module
+        module.load_state_dict(weights)
+        modules[module_name] = module.to(dtype=unet_dtype)
+        if len(modules) == 1:
+            module.is_first = True
+
+    #logger.info(f"loaded {ckpt_path} successfully, {len(modules)} modules")
+
+    patch_attn1 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN1)
+    patch_attn2 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN2)
+    control = ControlLLLiteAdvanced(patch_attn1=patch_attn1, patch_attn2=patch_attn2, timestep_keyframes=timestep_keyframe, device=load_device, ops=ops)
+    return control
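
The `set_model_patch` helper added above mirrors the logic in comfy's `model_patcher.py`, but operates on a bare `model_options` dict so that `live_model_patches` can patch a per-sample clone instead of the ModelPatcher itself. A self-contained sketch of the dict shape it produces (plain `object()` instances stand in for `LLLitePatch` objects):

```python
def set_model_patch(model_options: dict, patch, name: str):
    to = model_options["transformer_options"]
    # skip if this exact patch object was already added
    if "patches" in to and patch in to["patches"].get(name, []):
        return
    if "patches" not in to:
        to["patches"] = {}
    # build a new list rather than appending, so lists shared with other
    # cloned model_options dicts are never mutated in place
    to["patches"][name] = to["patches"].get(name, []) + [patch]

model_options = {"transformer_options": {}}
attn1_patch, attn2_patch = object(), object()
set_model_patch(model_options, attn1_patch, "attn1_patch")
set_model_patch(model_options, attn1_patch, "attn1_patch")  # duplicate, ignored
set_model_patch(model_options, attn2_patch, "attn2_patch")
assert model_options["transformer_options"]["patches"] == {
    "attn1_patch": [attn1_patch],
    "attn2_patch": [attn2_patch],
}
```
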
diff --git a/adv_control/control_reference.py b/adv_control/control_reference.py
index 21878df..28f1105 100644
--- a/adv_control/control_reference.py
+++ b/adv_control/control_reference.py
@@ -14,135 +14,7 @@
 from .logger import logger
 from .utils import (AdvancedControlBase, ControlWeights, TimestepKeyframeGroup, AbstractPreprocWrapper,
-                    deepcopy_with_sharing, prepare_mask_batch, broadcast_image_to_extend)
-
-
-def refcn_sample_factory(orig_comfy_sample: Callable, is_custom=False) -> Callable:
-    def get_refcn(control: ControlBase, order: int=-1):
-        ref_set: set[ReferenceAdvanced] = set()
-        if control is None:
-            return ref_set
-        if type(control) == ReferenceAdvanced:
-            control.order = order
-            order -= 1
-            ref_set.add(control)
-        ref_set.update(get_refcn(control.previous_controlnet, order=order))
-        return ref_set
-
-    def refcn_sample(model: ModelPatcher, *args, **kwargs):
-        # check if positive or negative conds contain ref cn
-        positive = args[-3]
-        negative = args[-2]
-        ref_set = set()
-        if positive is not None:
-            for cond in positive:
-                if "control" in cond[1]:
-                    ref_set.update(get_refcn(cond[1]["control"]))
-        if negative is not None:
-            for cond in negative:
-                if "control" in cond[1]:
-                    ref_set.update(get_refcn(cond[1]["control"]))
-        # if no ref cn found, do original function immediately
-        if len(ref_set) == 0:
-            return orig_comfy_sample(model, *args, **kwargs)
-        # otherwise, injection time
-        try:
-            # inject
-            # storage for all Reference-related injections
-            reference_injections = ReferenceInjections()
-
-            # first, handle attn module injection
-            all_modules = torch_dfs(model.model)
-            attn_modules: list[RefBasicTransformerBlock] = []
-            for module in all_modules:
-                if isinstance(module, BasicTransformerBlock):
-                    attn_modules.append(module)
-            attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
-            attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
-            for i, module in enumerate(attn_modules):
-                injection_holder = InjectionBasicTransformerBlockHolder(block=module, idx=i)
-                injection_holder.attn_weight = float(i) / float(len(attn_modules))
-                if hasattr(module, "_forward"): # backward compatibility
-                    module._forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
-                else:
-                    module.forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
-                module.injection_holder = injection_holder
-                reference_injections.attn_modules.append(module)
-            # figure out which module is middle block
-            if hasattr(model.model.diffusion_model, "middle_block"):
-                mid_modules = torch_dfs(model.model.diffusion_model.middle_block)
-                mid_attn_modules: list[RefBasicTransformerBlock] = [module for module in mid_modules if isinstance(module, BasicTransformerBlock)]
-                for module in mid_attn_modules:
-                    module.injection_holder.is_middle = True
-
-            # next, handle gn module injection (TimestepEmbedSequential)
-            # TODO: figure out the logic behind these hardcoded indexes
-            if type(model.model).__name__ == "SDXL":
-                input_block_indices = [4, 5, 7, 8]
-                output_block_indices = [0, 1, 2, 3, 4, 5]
-            else:
-                input_block_indices = [4, 5, 7, 8, 10, 11]
-                output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
-            if hasattr(model.model.diffusion_model, "middle_block"):
-                module = model.model.diffusion_model.middle_block
-                injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=0, is_middle=True)
-                injection_holder.gn_weight = 0.0
-                module.injection_holder = injection_holder
-                reference_injections.gn_modules.append(module)
-            for w, i in enumerate(input_block_indices):
-                module = model.model.diffusion_model.input_blocks[i]
-                injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_input=True)
-                injection_holder.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
-                module.injection_holder = injection_holder
-                reference_injections.gn_modules.append(module)
-            for w, i in enumerate(output_block_indices):
-                module = model.model.diffusion_model.output_blocks[i]
-                injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_output=True)
-                injection_holder.gn_weight = float(w) / float(len(output_block_indices))
-                module.injection_holder = injection_holder
-                reference_injections.gn_modules.append(module)
-            # hack gn_module forwards and update weights
-            for i, module in enumerate(reference_injections.gn_modules):
-                module.injection_holder.gn_weight *= 2
-
-            # handle diffusion_model forward injection
-            reference_injections.diffusion_model_orig_forward = model.model.diffusion_model.forward
-            model.model.diffusion_model.forward = factory_forward_inject_UNetModel(reference_injections).__get__(model.model.diffusion_model, type(model.model.diffusion_model))
-            # store ordered ref cns in model's transformer options
-            orig_model_options = model.model_options
-            new_model_options = model.model_options.copy()
-            new_model_options["transformer_options"] = model.model_options["transformer_options"].copy()
-            ref_list: list[ReferenceAdvanced] = list(ref_set)
-            new_model_options["transformer_options"][REF_CONTROL_LIST_ALL] = sorted(ref_list, key=lambda x: x.order)
-            model.model_options = new_model_options
-            # continue with original function
-            return orig_comfy_sample(model, *args, **kwargs)
-        finally:
-            # cleanup injections
-            # restore attn modules
-            attn_modules: list[RefBasicTransformerBlock] = reference_injections.attn_modules
-            for module in attn_modules:
-                module.injection_holder.restore(module)
-                module.injection_holder.clean()
-                del module.injection_holder
-            del attn_modules
-            # restore gn modules
-            gn_modules: list[RefTimestepEmbedSequential] = reference_injections.gn_modules
-            for module in gn_modules:
-                module.injection_holder.restore(module)
-                module.injection_holder.clean()
-                del module.injection_holder
-            del gn_modules
-            # restore diffusion_model forward function
-            model.model.diffusion_model.forward = reference_injections.diffusion_model_orig_forward.__get__(model.model.diffusion_model, type(model.model.diffusion_model))
-            # restore model_options
-            model.model_options = orig_model_options
-            # cleanup
-            reference_injections.cleanup()
-    return refcn_sample
-
-
-# inject sample functions
-comfy.sample.sample = refcn_sample_factory(comfy.sample.sample)
-comfy.sample.sample_custom = refcn_sample_factory(comfy.sample.sample_custom, is_custom=True)
+                    broadcast_image_to_extend)
 
 
 REF_ATTN_CONTROL_LIST = "ref_attn_control_list"
@@ -826,10 +698,3 @@ def forward_timestep_embed_ref_inject(*args, **kwargs):
         return y.to(x.dtype)
     return forward_timestep_embed_ref_inject
-
-
-# DFS Search for Torch.nn.Module, Written by Lvmin
-def torch_dfs(model: torch.nn.Module):
-    result = [model]
-    for child in model.children():
-        result += torch_dfs(child)
-    return result
diff --git a/adv_control/nodes.py b/adv_control/nodes.py
index 6f79056..b39fbfb 100644
--- a/adv_control/nodes.py
+++ b/adv_control/nodes.py
@@ -2,6 +2,7 @@
 from torch import Tensor
 
 import folder_paths
+import comfy.sample
 from comfy.model_patcher import ModelPatcher
 
 from .control import load_controlnet, convert_to_advanced, is_advanced_controlnet, is_sd3_advanced_controlnet
@@ -17,6 +18,11 @@
 from .nodes_deprecated import LoadImagesFromDirectory
 from .logger import logger
 
+from .sampling import acn_sample_factory
+# inject sample functions
+comfy.sample.sample = acn_sample_factory(comfy.sample.sample)
+comfy.sample.sample_custom = acn_sample_factory(comfy.sample.sample_custom, is_custom=True)
+
 
 class ControlNetLoaderAdvanced:
     @classmethod
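
nodes.py now installs the sampling hook at import time by wrapping `comfy.sample.sample` in place, the same factory pattern control_reference.py used before this refactor. Reduced to its essentials (the names below are illustrative stand-ins, not comfy APIs):

```python
from typing import Callable

def library_sample(model, steps):
    # stand-in for comfy.sample.sample
    return f"sampled {model!r} for {steps} steps"

def sample_factory(orig_sample: Callable) -> Callable:
    def wrapped_sample(model, *args, **kwargs):
        # setup would go here (convert conds, clone model_options, inject patches, ...)
        try:
            return orig_sample(model, *args, **kwargs)
        finally:
            # teardown always runs, even if sampling raises,
            # so injections never leak into the next run
            pass
    return wrapped_sample

# module-level injection: executes once, when the module is first imported
library_sample = sample_factory(library_sample)
print(library_sample("unet", 20))
```
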
diff --git a/adv_control/sampling.py b/adv_control/sampling.py
new file mode 100644
index 0000000..81f0407
--- /dev/null
+++ b/adv_control/sampling.py
@@ -0,0 +1,191 @@
+from typing import Callable, Union
+
+import comfy.sample
+from comfy.model_patcher import ModelPatcher
+from comfy.controlnet import ControlBase
+from comfy.ldm.modules.attention import BasicTransformerBlock
+
+
+from .control import convert_all_to_advanced, restore_all_controlnet_conns
+from .control_reference import (ReferenceAdvanced, ReferenceInjections,
+                                RefBasicTransformerBlock, RefTimestepEmbedSequential,
+                                InjectionBasicTransformerBlockHolder, InjectionTimestepEmbedSequentialHolder,
+                                _forward_inject_BasicTransformerBlock, factory_forward_inject_UNetModel,
+                                REF_CONTROL_LIST_ALL)
+from .control_lllite import (ControlLLLiteAdvanced)
+from .utils import torch_dfs
+
+
+def support_sliding_context_windows(model, positive, negative) -> tuple[bool, dict, dict]:
+    if not hasattr(model, "motion_injection_params"):
+        return False, positive, negative
+    motion_injection_params = getattr(model, "motion_injection_params")
+    context_options = getattr(motion_injection_params, "context_options")
+    if context_options.context_length is None:
+        return False, positive, negative
+    # convert to advanced, with report if anything was actually modified
+    modified, new_conds = convert_all_to_advanced([positive, negative])
+    positive, negative = new_conds
+    return modified, positive, negative
+
+
+def acn_sample_factory(orig_comfy_sample: Callable, is_custom=False) -> Callable:
+    def get_refcn(control: ControlBase, order: int=-1):
+        ref_set: set[ReferenceAdvanced] = set()
+        if control is None:
+            return ref_set
+        if type(control) == ReferenceAdvanced:
+            control.order = order
+            order -= 1
+            ref_set.add(control)
+        ref_set.update(get_refcn(control.previous_controlnet, order=order))
+        return ref_set
+
+    def get_lllitecn(control: ControlBase):
+        cn_dict: dict[ControlLLLiteAdvanced,None] = {}
+        if control is None:
+            return cn_dict
+        if type(control) == ControlLLLiteAdvanced:
+            cn_dict[control] = None
+        cn_dict.update(get_lllitecn(control.previous_controlnet))
+        return cn_dict
+
+    def acn_sample(model: ModelPatcher, *args, **kwargs):
+        controlnets_modified = False
+        orig_positive = args[-3]
+        orig_negative = args[-2]
+        try:
+            orig_model_options = model.model_options
+            # check if positive or negative conds contain ref cn
+            positive = args[-3]
+            negative = args[-2]
+            # if context options present, convert all CNs to Advanced if needed
+            controlnets_modified, positive, negative = support_sliding_context_windows(model, positive, negative)
+            if controlnets_modified:
+                args = list(args)
+                args[-3] = positive
+                args[-2] = negative
+                args = tuple(args)
+            # look for Advanced ControlNets that will require intervention to work
+            ref_set = set()
+            lllite_dict: dict[ControlLLLiteAdvanced, None] = {} # dicts preserve insertion order since py3.7
+            if positive is not None:
+                for cond in positive:
+                    if "control" in cond[1]:
+                        ref_set.update(get_refcn(cond[1]["control"]))
+                        lllite_dict.update(get_lllitecn(cond[1]["control"]))
+            if negative is not None:
+                for cond in negative:
+                    if "control" in cond[1]:
+                        ref_set.update(get_refcn(cond[1]["control"]))
+                        lllite_dict.update(get_lllitecn(cond[1]["control"]))
+            # if lllite found, apply patches to a cloned model_options, and continue
+            if len(lllite_dict) > 0:
+                lllite_list = list(lllite_dict.keys())
+                model.model_options = model.model_options.copy()
+                model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
+                lllite_list.reverse() # reverse so that patches will be applied in expected order
+                for lll in lllite_list:
+                    lll.live_model_patches(model.model_options)
+            # if no ref cn found, do original function immediately
+            if len(ref_set) == 0:
+                return orig_comfy_sample(model, *args, **kwargs)
+            # otherwise, injection time
+            try:
+                # inject
+                # storage for all Reference-related injections
+                reference_injections = ReferenceInjections()
+
+                # first, handle attn module injection
+                all_modules = torch_dfs(model.model)
+                attn_modules: list[RefBasicTransformerBlock] = []
+                for module in all_modules:
+                    if isinstance(module, BasicTransformerBlock):
+                        attn_modules.append(module)
+                attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
+                attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+                for i, module in enumerate(attn_modules):
+                    injection_holder = InjectionBasicTransformerBlockHolder(block=module, idx=i)
+                    injection_holder.attn_weight = float(i) / float(len(attn_modules))
+                    if hasattr(module, "_forward"): # backward compatibility
+                        module._forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
+                    else:
+                        module.forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
+                    module.injection_holder = injection_holder
+                    reference_injections.attn_modules.append(module)
+                # figure out which module is middle block
+                if hasattr(model.model.diffusion_model, "middle_block"):
+                    mid_modules = torch_dfs(model.model.diffusion_model.middle_block)
+                    mid_attn_modules: list[RefBasicTransformerBlock] = [module for module in mid_modules if isinstance(module, BasicTransformerBlock)]
+                    for module in mid_attn_modules:
+                        module.injection_holder.is_middle = True
+
+                # next, handle gn module injection (TimestepEmbedSequential)
+                # TODO: figure out the logic behind these hardcoded indexes
+                if type(model.model).__name__ == "SDXL":
+                    input_block_indices = [4, 5, 7, 8]
+                    output_block_indices = [0, 1, 2, 3, 4, 5]
+                else:
+                    input_block_indices = [4, 5, 7, 8, 10, 11]
+                    output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
+                if hasattr(model.model.diffusion_model, "middle_block"):
+                    module = model.model.diffusion_model.middle_block
+                    injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=0, is_middle=True)
+                    injection_holder.gn_weight = 0.0
+                    module.injection_holder = injection_holder
+                    reference_injections.gn_modules.append(module)
+                for w, i in enumerate(input_block_indices):
+                    module = model.model.diffusion_model.input_blocks[i]
+                    injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_input=True)
+                    injection_holder.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
+                    module.injection_holder = injection_holder
+                    reference_injections.gn_modules.append(module)
+                for w, i in enumerate(output_block_indices):
+                    module = model.model.diffusion_model.output_blocks[i]
+                    injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_output=True)
+                    injection_holder.gn_weight = float(w) / float(len(output_block_indices))
+                    module.injection_holder = injection_holder
+                    reference_injections.gn_modules.append(module)
+                # hack gn_module forwards and update weights
+                for i, module in enumerate(reference_injections.gn_modules):
+                    module.injection_holder.gn_weight *= 2
+
+                # handle diffusion_model forward injection
+                reference_injections.diffusion_model_orig_forward = model.model.diffusion_model.forward
+                model.model.diffusion_model.forward = factory_forward_inject_UNetModel(reference_injections).__get__(model.model.diffusion_model, type(model.model.diffusion_model))
+                # store ordered ref cns in model's transformer options
+                new_model_options = model.model_options.copy()
+                new_model_options["transformer_options"] = model.model_options["transformer_options"].copy()
+                ref_list: list[ReferenceAdvanced] = list(ref_set)
+                new_model_options["transformer_options"][REF_CONTROL_LIST_ALL] = sorted(ref_list, key=lambda x: x.order)
+                model.model_options = new_model_options
+                # continue with original function
+                return orig_comfy_sample(model, *args, **kwargs)
+            finally:
+                # cleanup injections
+                # restore attn modules
+                attn_modules: list[RefBasicTransformerBlock] = reference_injections.attn_modules
+                for module in attn_modules:
+                    module.injection_holder.restore(module)
+                    module.injection_holder.clean()
+                    del module.injection_holder
+                del attn_modules
+                # restore gn modules
+                gn_modules: list[RefTimestepEmbedSequential] = reference_injections.gn_modules
+                for module in gn_modules:
+                    module.injection_holder.restore(module)
+                    module.injection_holder.clean()
+                    del module.injection_holder
+                del gn_modules
+                # restore diffusion_model forward function
+                model.model.diffusion_model.forward = reference_injections.diffusion_model_orig_forward.__get__(model.model.diffusion_model, type(model.model.diffusion_model))
+                # cleanup
+                reference_injections.cleanup()
+        finally:
+            # restore model_options
+            model.model_options = orig_model_options
+            # restore controlnets in conds, if needed
+            if controlnets_modified:
+                restore_all_controlnet_conns([orig_positive, orig_negative])
+
+    return acn_sample
diff --git a/adv_control/utils.py b/adv_control/utils.py
index 726233b..f385133 100644
--- a/adv_control/utils.py
+++ b/adv_control/utils.py
@@ -540,6 +540,14 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
     return sorted_list
 
 
+# DFS Search for Torch.nn.Module, Written by Lvmin
+def torch_dfs(model: torch.nn.Module):
+    result = [model]
+    for child in model.children():
+        result += torch_dfs(child)
+    return result
+
+
 class WeightTypeException(TypeError):
     "Raised when weight not compatible with AdvancedControlBase object"
     pass
diff --git a/pyproject.toml b/pyproject.toml
index 5662632..28657f6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "comfyui-advanced-controlnet"
 description = "Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks."
-version = "1.1.3"
+version = "1.1.4"
 license = "LICENSE"
 dependencies = []
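
One detail from sampling.py worth calling out: LLLite controlnets are collected into a `dict` keyed by object rather than a `set`, because dict keys deduplicate while preserving insertion order (guaranteed since Python 3.7), which keeps patch application deterministic. An illustrative sketch with a hypothetical stub class:

```python
class StubCN:
    def __init__(self, name, previous_controlnet=None):
        self.name = name
        self.previous_controlnet = previous_controlnet

def collect(cn, found: dict):
    # walk the previous_controlnet chain, like get_lllitecn does recursively
    while cn is not None:
        found[cn] = None  # dict-as-ordered-set: dedupes, keeps first-seen order
        cn = cn.previous_controlnet
    return found

chain = StubCN("a", StubCN("b"))
found = collect(chain, {})
collect(chain, found)  # a second cond referencing the same chain adds nothing
assert [cn.name for cn in found] == ["a", "b"]
```
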