Skip to content

Commit

Permalink
Started scaffolding for other hook types, refactored get_hooks_from_c…
Browse files Browse the repository at this point in the history
…ond to organize hooks by type
  • Loading branch information
Kosinkadink committed Sep 21, 2024
1 parent 59d72b4 commit 5f450d3
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 24 deletions.
101 changes: 91 additions & 10 deletions comfy/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_patcher import ModelPatcher, PatcherInjection
from comfy.model_base import BaseModel
from comfy.sd import CLIP
import comfy.lora
Expand All @@ -19,7 +19,11 @@ class EnumHookMode(enum.Enum):
class EnumHookType(enum.Enum):
Weight = "weight"
Patch = "patch"
AddModel = "addmodel"
ObjectPatch = "object_patch"
AddModels = "add_models"
AddCallback = "add_callback"
SetInjections = "add_injections"
AddWrapper = "add_wrapper"

class EnumWeightTarget(enum.Enum):
Model = "model"
Expand Down Expand Up @@ -125,18 +129,94 @@ def clone(self, subtype: Callable=None):
c: PatchHook = super().clone(subtype)
c.patches = self.patches
return c

def add_hook_patches(self, model: 'ModelPatcher'):
pass

class AddModelHook(Hook):
def __init__(self, model: 'ModelPatcher'):
super().__init__(hook_type=EnumHookType.AddModel)
self.model = model
class ObjectPatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches: Dict = None

def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: ObjectPatchHook = super().clone(subtype)
c.object_patches = self.object_patches
return c

def add_hook_object_patches(self, model: 'ModelPatcher'):
pass

class AddModelsHook(Hook):
def __init__(self, key: str=None, models: List['ModelPatcher']=None):
super().__init__(hook_type=EnumHookType.AddModels)
self.key = key
self.models = models
self.append_when_same = True

def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: AddModelHook = super().clone(subtype)
c.model = self.model
c: AddModelsHook = super().clone(subtype)
c.key = self.key
c.models = self.models.copy() if self.models else self.models
c.append_when_same = self.append_when_same
return c

def add_hook_models(self, model: 'ModelPatcher'):
pass

class AddCallbackHook(Hook):
def __init__(self, key: str=None, callback: Callable=None):
super().__init__(hook_type=EnumHookType.AddCallback)
self.key = key
self.callback = callback

def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: AddCallbackHook = super().clone(subtype)
c.key = self.key
c.callback = self.callback
return c

def add_hook_callback(self, model: 'ModelPatcher'):
pass

class SetInjectionsHook(Hook):
def __init__(self, key: str=None, injections: List['PatcherInjection']=None):
super().__init__(hook_type=EnumHookType.SetInjections)
self.key = key
self.injections = injections

def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: SetInjectionsHook = super().clone(subtype)
c.key = self.key
c.injections = self.injections.copy() if self.injections else self.injections
return c

def add_hook_injections(self, model: 'ModelPatcher'):
pass

class AddWrapperHook(Hook):
def __init__(self, key: str=None, wrapper: Callable=None):
super().__init__(hook_type=EnumHookType.AddWrapper)
self.key = key
self.wrapper = wrapper

def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: AddWrapperHook = super().clone(subtype)
c.key = self.key
c.wrapper = self.wrapper
return c

def add_hook_wrapper(self, model: 'ModelPatcher'):
pass

class HookGroup:
def __init__(self):
Expand Down Expand Up @@ -167,9 +247,10 @@ def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
hook.hook_keyframe = hook_kf

def get_dict_repr(self):
d = {}
d: Dict[EnumHookType, Dict[Hook, None]] = {}
for hook in self.hooks:
d[hook] = None
with_type = d.setdefault(hook.hook_type, {})
with_type[hook] = None
return d

@staticmethod
Expand Down
9 changes: 4 additions & 5 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,13 +809,12 @@ def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: com
if cached_group.contains(hook):
self.cached_hook_patches.pop(cached_group)

def register_all_hook_patches(self, hooks_dict: Dict[comfy.hooks.Hook, None], target: comfy.hooks.EnumWeightTarget):
def register_all_hook_patches(self, hooks_dict: Dict[comfy.hooks.EnumHookType, Dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget):
self.restore_hook_patches()
weight_hooks_to_register: List[comfy.hooks.WeightHook] = []
for hook in hooks_dict:
if hook.hook_type == comfy.hooks.EnumHookType.Weight:
if hook.hook_ref not in self.hook_patches:
weight_hooks_to_register.append(hook)
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
if hook.hook_ref not in self.hook_patches:
weight_hooks_to_register.append(hook)
if len(weight_hooks_to_register) > 0:
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
for hook in weight_hooks_to_register:
Expand Down
17 changes: 8 additions & 9 deletions comfy/sampler_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@ def get_models_from_cond(cond, model_type):
models += [c[model_type]]
return models

def get_hooks_from_cond(cond, filter_types: List[comfy.hooks.EnumHookType]=None):
hooks: Dict[comfy.hooks.Hook, None] = {}
def get_hooks_from_cond(cond, hooks_dict: Dict[comfy.hooks.EnumHookType, Dict[comfy.hooks.Hook, None]]):
for c in cond:
if 'hooks' in c:
for hook in c['hooks'].hooks:
hook: comfy.hooks.Hook
if not filter_types or hook.hook_type in filter_types:
hooks[hook] = None
return hooks
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
return hooks_dict

def convert_cond(cond):
out = []
Expand All @@ -53,13 +52,13 @@ def get_additional_models(conds, dtype):
cnets: List[ControlBase] = []
gligen = []
add_models = []
hooks: Dict[comfy.hooks.AddModelHook, None] = {}
hooks: Dict[comfy.hooks.EnumHookType, Dict[comfy.hooks.Hook, None]] = {}

for k in conds:
cnets += get_models_from_cond(conds[k], "control")
gligen += get_models_from_cond(conds[k], "gligen")
add_models += get_models_from_cond(conds[k], "additional_models")
hooks.update(get_hooks_from_cond(conds[k], [comfy.hooks.EnumHookType.AddModel]))
get_hooks_from_cond(conds[k], hooks)

control_nets = set(cnets)

Expand All @@ -70,7 +69,7 @@ def get_additional_models(conds, dtype):
inference_memory += m.inference_memory_requirements(dtype)

gligen = [x[1] for x in gligen]
hook_models = [x.model for x in hooks]
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
models = control_models + gligen + add_models + hook_models

return models, inference_memory
Expand Down Expand Up @@ -108,5 +107,5 @@ def prepare_model_patcher(model: 'ModelPatcher', conds):
# check for hooks in conds - if not registered, see if can be applied
hooks = {}
for k in conds:
hooks.update(get_hooks_from_cond(conds[k]))
get_hooks_from_cond(conds[k], hooks)
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model)

0 comments on commit 5f450d3

Please sign in to comment.