Added wrappers to ModelPatcher to facilitate standardized function wrapping
Kosinkadink committed Sep 20, 2024
1 parent 5501429 commit 59d72b4
Showing 3 changed files with 90 additions and 23 deletions.
9 changes: 6 additions & 3 deletions comfy/hooks.py
@@ -173,15 +173,18 @@ def get_dict_repr(self):
return d

@staticmethod
def combine_all_hooks(hooks_list: List['HookGroup'], require_count=1) -> 'HookGroup':
def combine_all_hooks(hooks_list: List['HookGroup'], require_count=0) -> 'HookGroup':
actual: List[HookGroup] = []
for group in hooks_list:
if group is not None:
actual.append(group)
if len(actual) < require_count:
raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
# if only 1 hook, just reutnr itself without cloning
if len(actual) == 1:
# if no hooks, then return None
if len(actual) == 0:
return None
# if only 1 hook, just return itself without cloning
elif len(actual) == 1:
return actual[0]
final_hook: HookGroup = None
for hook in actual:
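The hooks.py change above makes an empty hook list a valid input: with the new require_count=0 default, combine_all_hooks returns None instead of raising when every entry is None, and a single group is still returned without cloning. A minimal sketch of the new behavior (illustrative only; it assumes comfy.hooks is importable and that HookGroup takes no required constructor arguments):

    from comfy.hooks import HookGroup

    # With require_count=0, an all-None list no longer raises; it yields None.
    assert HookGroup.combine_all_hooks([None, None]) is None

    # A single non-None group is returned as-is, without cloning.
    group = HookGroup()
    assert HookGroup.combine_all_hooks([group, None]) is group
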
72 changes: 62 additions & 10 deletions comfy/model_patcher.py
@@ -125,6 +125,43 @@ def init_callbacks(cls):
cls.ON_EJECT_MODEL: [],
}

class WrappersMP:
OUTER_SAMPLE = "outer_sample"

@classmethod
def init_wrappers(cls):
return {
cls.OUTER_SAMPLE: [],
}

class WrapperExecutor:
def __init__(self, original: Callable, wrappers: List[Callable], idx: int):
self.original = original
self.wrappers = wrappers.copy()
self.idx = idx
self.is_last = idx == len(wrappers)

def __call__(self, guider, *args, **kwargs):
new_executor = self._create_next_executor()
return new_executor._execute(guider, *args, **kwargs)

def _execute(self, guider, *args, **kwargs):
args = list(args)
kwargs = dict(kwargs)
if self.is_last:
return self.original(*args, **kwargs)
return self.wrappers[self.idx](self, guider, *args, **kwargs)

def _create_next_executor(self):
new_idx = self.idx + 1
if new_idx > len(self.wrappers):
raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.")
return WrapperExecutor(self.original, self.wrappers, new_idx)

@classmethod
def new_executor(cls, original: Callable, wrappers: List[Callable]):
return cls(original, wrappers, idx=0)

class AutoPatcherEjector:
def __init__(self, model: 'ModelPatcher', skip_until_exit=False):
self.model = model
@@ -176,6 +213,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up
self.attachments: Dict[str] = {}
self.additional_models: Dict[str, List[ModelPatcher]] = {}
self.callbacks: Dict[str, List[Callable]] = CallbacksMP.init_callbacks()
self.wrappers: Dict[str, List[Callable]] = WrappersMP.init_wrappers()

self.is_injected = False
self.skip_injection = False
@@ -236,6 +274,9 @@ def clone(self):
# callbacks
for k, c in self.callbacks.items():
n.callbacks[k] = c.copy()
# sample wrappers
for k, w in self.wrappers.items():
n.wrappers[k] = w.copy()
# injection
n.is_injected = self.is_injected
n.skip_injection = self.skip_injection
@@ -254,7 +295,7 @@ def clone(self):
n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
n.hook_mode = self.hook_mode

for callback in self.callbacks[CallbacksMP.ON_CLONE]:
for callback in self.get_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n

@@ -545,7 +586,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter

for callback in self.callbacks[CallbacksMP.ON_LOAD]:
for callback in self.get_callbacks(CallbacksMP.ON_LOAD):
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)

self.apply_hooks(self.forced_hooks)
@@ -677,8 +718,7 @@ def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float3

def cleanup(self):
self.clean_hooks()
self.restore_hook_patches()
for callback in self.callbacks[CallbacksMP.ON_CLEANUP]:
for callback in self.get_callbacks(CallbacksMP.ON_CLEANUP):
callback(self)

def get_all_additional_models(self):
@@ -692,6 +732,17 @@ def add_callback(self, key: str, callback: Callable):
raise Exception(f"Callback '{key}' is not recognized.")
self.callbacks[key].append(callback)

def get_callbacks(self, key: str):
return self.callbacks.get(key, [])

def add_wrapper(self, key: str, wrapper: Callable):
if key not in self.wrappers:
raise Exception(f"Wrapper '{key}' is not recognized.")
self.wrappers[key].append(wrapper)

def get_wrappers(self, key: str):
return self.wrappers.get(key, [])

def set_attachments(self, key: str, attachment):
self.attachments[key] = attachment

@@ -712,7 +763,7 @@ def inject_model(self):
inj.inject(self)
self.is_injected = True
if self.is_injected:
for callback in self.callbacks[CallbacksMP.ON_INJECT_MODEL]:
for callback in self.get_callbacks(CallbacksMP.ON_INJECT_MODEL):
callback(self)

def eject_model(self):
@@ -722,15 +773,15 @@ def eject_model(self):
for inj in injections:
inj.eject(self)
self.is_injected = False
for callback in self.callbacks[CallbacksMP.ON_EJECT_MODEL]:
for callback in self.get_callbacks(CallbacksMP.ON_EJECT_MODEL):
callback(self)

def pre_run(self):
for callback in self.callbacks[CallbacksMP.ON_PRE_RUN]:
for callback in self.get_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)

def prepare_state(self, timestep):
for callback in self.callbacks[CallbacksMP.ON_PREPARE_STATE]:
for callback in self.get_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep)

def restore_hook_patches(self):
@@ -769,7 +820,7 @@ def register_all_hook_patches(self, hooks_dict: Dict[comfy.hooks.Hook, None], ta
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
for hook in weight_hooks_to_register:
hook.add_hook_patches(self, target)
for callback in self.callbacks[CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES]:
for callback in self.get_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
callback(self, hooks_dict, target)

def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0, is_diff=False):
@@ -851,7 +902,7 @@ def apply_hooks(self, hooks: comfy.hooks.HookGroup):
if self.current_hooks == hooks:
return
self.patch_hooks(hooks=hooks)
for callback in self.callbacks[CallbacksMP.ON_APPLY_HOOKS]:
for callback in self.get_callbacks(CallbacksMP.ON_APPLY_HOOKS):
callback(self, hooks)

def patch_hooks(self, hooks: comfy.hooks.HookGroup):
@@ -907,6 +958,7 @@ def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_pat
# TODO: properly handle lowvram situations for cached hook patches
temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True)
out_weight = comfy.lora.calculate_weight(combined_patches[key], temp_weight, key, original_weights=original_weights).to(weight.dtype)
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
self.cached_hook_patches.setdefault(hooks, {})
self.cached_hook_patches[hooks][key] = out_weight
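As a quick illustration of how the new WrapperExecutor chains calls, here is a hypothetical sketch, not part of this commit; it assumes comfy.model_patcher is importable and uses made-up functions work, wrapper_a, and wrapper_b:

    from comfy.model_patcher import WrapperExecutor

    def work(x):
        # The original callable; note it does not receive 'guider'.
        return x * 2

    def wrapper_a(executor, guider, x):
        print("wrapper_a: before")
        result = executor(guider, x)  # hand off to the next wrapper, or to 'work'
        print("wrapper_a: after")
        return result

    def wrapper_b(executor, guider, x):
        print("wrapper_b: before")
        return executor(guider, x)

    executor = WrapperExecutor.new_executor(work, [wrapper_a, wrapper_b])
    # The entry point calls _execute directly so the first wrapper is not skipped,
    # mirroring how CFGGuider.sample uses it in comfy/samplers.py below.
    print(executor._execute(object(), 21))
    # prints: wrapper_a: before / wrapper_b: before / wrapper_a: after / 42

Wrappers therefore run outermost-first, each deciding whether and how to continue down the chain before the original function executes.
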
32 changes: 22 additions & 10 deletions comfy/samplers.py
@@ -10,6 +10,7 @@
import math
import logging
import comfy.sampler_helpers
import comfy.model_patcher
import comfy.hooks
import scipy.stats
import numpy
@@ -766,14 +767,7 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mas
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))

def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
if sigmas.shape[-1] == 0:
return latent_image

self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))

def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
device = self.model_patcher.load_device

@@ -786,17 +780,35 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba

try:
self.model_patcher.pre_run()
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds)
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
self.model_patcher.cleanup()

comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model
del self.conds
del self.loaded_models
return output

def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
if sigmas.shape[-1] == 0:
return latent_image

self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))

try:
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds)
executor = comfy.model_patcher.WrapperExecutor.new_executor(
self.outer_sample,
self.model_patcher.get_wrappers(comfy.model_patcher.WrappersMP.OUTER_SAMPLE))
output = executor._execute(self, noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
self.model_patcher.restore_hook_patches()

del self.conds
return output


def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
cfg_guider = CFGGuider(model)
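Tying the two files together: sampling setup now lives in outer_sample, and CFGGuider.sample routes it through any OUTER_SAMPLE wrappers registered on the ModelPatcher. A hypothetical sketch of how an extension might use this (log_outer_sample is an illustrative name, and 'model' is assumed to be an existing ModelPatcher; neither is part of the commit):

    import comfy.model_patcher

    def log_outer_sample(executor, guider, noise, latent_image, sampler, sigmas, *args, **kwargs):
        # 'guider' is the CFGGuider instance; calling executor(...) continues on to
        # the next wrapper and, last of all, to CFGGuider.outer_sample itself.
        print(f"outer_sample wrapped: {sigmas.shape[-1]} sigmas")
        return executor(guider, noise, latent_image, sampler, sigmas, *args, **kwargs)

    patched = model.clone()  # 'model' is assumed to be an existing ModelPatcher
    patched.add_wrapper(comfy.model_patcher.WrappersMP.OUTER_SAMPLE, log_outer_sample)
    # clone() copies wrappers, so sampling with 'patched' passes through log_outer_sample.
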
