From 5f450d3351916230e32da64640dfa0b188604ba4 Mon Sep 17 00:00:00 2001 From: "kosinkadink1@gmail.com" Date: Sat, 21 Sep 2024 10:37:18 +0900 Subject: [PATCH] Started scaffolding for other hook types, refactored get_hooks_from_cond to organize hooks by type --- comfy/hooks.py | 101 +++++++++++++++++++++++++++++++++++---- comfy/model_patcher.py | 9 ++-- comfy/sampler_helpers.py | 17 ++++--- 3 files changed, 103 insertions(+), 24 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index eb58509f27b..980d16ed50d 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -5,7 +5,7 @@ import numpy as np if TYPE_CHECKING: - from comfy.model_patcher import ModelPatcher + from comfy.model_patcher import ModelPatcher, PatcherInjection from comfy.model_base import BaseModel from comfy.sd import CLIP import comfy.lora @@ -19,7 +19,11 @@ class EnumHookMode(enum.Enum): class EnumHookType(enum.Enum): Weight = "weight" Patch = "patch" - AddModel = "addmodel" + ObjectPatch = "object_patch" + AddModels = "add_models" + AddCallback = "add_callback" + SetInjections = "add_injections" + AddWrapper = "add_wrapper" class EnumWeightTarget(enum.Enum): Model = "model" @@ -125,18 +129,94 @@ def clone(self, subtype: Callable=None): c: PatchHook = super().clone(subtype) c.patches = self.patches return c + + def add_hook_patches(self, model: 'ModelPatcher'): + pass -class AddModelHook(Hook): - def __init__(self, model: 'ModelPatcher'): - super().__init__(hook_type=EnumHookType.AddModel) - self.model = model +class ObjectPatchHook(Hook): + def __init__(self): + super().__init__(hook_type=EnumHookType.ObjectPatch) + self.object_patches: Dict = None + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: ObjectPatchHook = super().clone(subtype) + c.object_patches = self.object_patches + return c + + def add_hook_object_patches(self, model: 'ModelPatcher'): + pass + +class AddModelsHook(Hook): + def __init__(self, key: str=None, models: List['ModelPatcher']=None): + super().__init__(hook_type=EnumHookType.AddModels) + self.key = key + self.models = models + self.append_when_same = True def clone(self, subtype: Callable=None): if subtype is None: subtype = type(self) - c: AddModelHook = super().clone(subtype) - c.model = self.model + c: AddModelsHook = super().clone(subtype) + c.key = self.key + c.models = self.models.copy() if self.models else self.models + c.append_when_same = self.append_when_same + return c + + def add_hook_models(self, model: 'ModelPatcher'): + pass + +class AddCallbackHook(Hook): + def __init__(self, key: str=None, callback: Callable=None): + super().__init__(hook_type=EnumHookType.AddCallback) + self.key = key + self.callback = callback + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: AddCallbackHook = super().clone(subtype) + c.key = self.key + c.callback = self.callback return c + + def add_hook_callback(self, model: 'ModelPatcher'): + pass + +class SetInjectionsHook(Hook): + def __init__(self, key: str=None, injections: List['PatcherInjection']=None): + super().__init__(hook_type=EnumHookType.SetInjections) + self.key = key + self.injections = injections + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: SetInjectionsHook = super().clone(subtype) + c.key = self.key + c.injections = self.injections.copy() if self.injections else self.injections + return c + + def add_hook_injections(self, model: 'ModelPatcher'): + pass + +class AddWrapperHook(Hook): + def __init__(self, key: str=None, wrapper: Callable=None): + super().__init__(hook_type=EnumHookType.AddWrapper) + self.key = key + self.wrapper = wrapper + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: AddWrapperHook = super().clone(subtype) + c.key = self.key + c.wrapper = self.wrapper + return c + + def add_hook_wrapper(self, model: 'ModelPatcher'): + pass class HookGroup: def __init__(self): @@ -167,9 +247,10 @@ def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'): hook.hook_keyframe = hook_kf def get_dict_repr(self): - d = {} + d: Dict[EnumHookType, Dict[Hook, None]] = {} for hook in self.hooks: - d[hook] = None + with_type = d.setdefault(hook.hook_type, {}) + with_type[hook] = None return d @staticmethod diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2541c299aed..f5dc81b33f5 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -809,13 +809,12 @@ def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: com if cached_group.contains(hook): self.cached_hook_patches.pop(cached_group) - def register_all_hook_patches(self, hooks_dict: Dict[comfy.hooks.Hook, None], target: comfy.hooks.EnumWeightTarget): + def register_all_hook_patches(self, hooks_dict: Dict[comfy.hooks.EnumHookType, Dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget): self.restore_hook_patches() weight_hooks_to_register: List[comfy.hooks.WeightHook] = [] - for hook in hooks_dict: - if hook.hook_type == comfy.hooks.EnumHookType.Weight: - if hook.hook_ref not in self.hook_patches: - weight_hooks_to_register.append(hook) + for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): + if hook.hook_ref not in self.hook_patches: + weight_hooks_to_register.append(hook) if len(weight_hooks_to_register) > 0: self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) for hook in weight_hooks_to_register: diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index af89a715a96..36325cfa2ec 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -26,15 +26,14 @@ def get_models_from_cond(cond, model_type): models += [c[model_type]] return models -def get_hooks_from_cond(cond, filter_types: List[comfy.hooks.EnumHookType]=None): - hooks: Dict[comfy.hooks.Hook, None] = {} +def get_hooks_from_cond(cond, hooks_dict: Dict[comfy.hooks.EnumHookType, Dict[comfy.hooks.Hook, None]]): for c in cond: if 'hooks' in c: for hook in c['hooks'].hooks: hook: comfy.hooks.Hook - if not filter_types or hook.hook_type in filter_types: - hooks[hook] = None - return hooks + with_type = hooks_dict.setdefault(hook.hook_type, {}) + with_type[hook] = None + return hooks_dict def convert_cond(cond): out = [] @@ -53,13 +52,13 @@ def get_additional_models(conds, dtype): cnets: List[ControlBase] = [] gligen = [] add_models = [] - hooks: Dict[comfy.hooks.AddModelHook, None] = {} + hooks: Dict[comfy.hooks.EnumHookType, Dict[comfy.hooks.Hook, None]] = {} for k in conds: cnets += get_models_from_cond(conds[k], "control") gligen += get_models_from_cond(conds[k], "gligen") add_models += get_models_from_cond(conds[k], "additional_models") - hooks.update(get_hooks_from_cond(conds[k], [comfy.hooks.EnumHookType.AddModel])) + get_hooks_from_cond(conds[k], hooks) control_nets = set(cnets) @@ -70,7 +69,7 @@ def get_additional_models(conds, dtype): inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] - hook_models = [x.model for x in hooks] + hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()] models = control_models + gligen + add_models + hook_models return models, inference_memory @@ -108,5 +107,5 @@ def prepare_model_patcher(model: 'ModelPatcher', conds): # check for hooks in conds - if not registered, see if can be applied hooks = {} for k in conds: - hooks.update(get_hooks_from_cond(conds[k])) + get_hooks_from_cond(conds[k], hooks) model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model)