Skip to content

Commit

Permalink
Updated clone_has_same_weights function to account for new ModelPatch…
Browse files Browse the repository at this point in the history
…er properties, improved AutoPatcherEjector usage in partially_load
  • Loading branch information
Kosinkadink committed Sep 21, 2024
1 parent f28d892 commit 298397d
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,28 +163,27 @@ def new_executor(cls, original: Callable, wrappers: List[Callable]):
return cls(original, wrappers, idx=0)

class AutoPatcherEjector:
def __init__(self, model: 'ModelPatcher', skip_until_exit=False):
def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
self.model = model
self.was_injected = False
self.prev_skip_injection = False
self.skip_until_exit = skip_until_exit
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only

def __enter__(self):
self.was_injected = False
self.prev_skip_injection = self.model.skip_injection
if self.skip_until_exit:
if self.skip_and_inject_on_exit_only:
self.model.skip_injection = True
if self.model.is_injected:
self.model.eject_model()
self.was_injected = True

def __exit__(self, *args):
if self.was_injected:
if self.skip_until_exit:
self.model.skip_injection = self.prev_skip_injection
self.model.inject_model()
elif not self.model.skip_injection:
self.model.inject_model()
if self.skip_and_inject_on_exit_only:
self.model.skip_injection = self.prev_skip_injection
self.model.inject_model()
if self.was_injected and not self.model.skip_injection:
self.model.inject_model()
self.model.skip_injection = self.prev_skip_injection

class PatcherInjection:
Expand Down Expand Up @@ -319,6 +318,18 @@ def clone_has_same_weights(self, clone: 'ModelPatcher'):
return False
if self.hook_patches.keys() != clone.hook_patches.keys():
return False
if self.attachments.keys() != clone.attachments.keys():
return False
if self.additional_models.keys() != clone.additional_models.keys():
return False
for key in self.callbacks:
if len(self.callbacks[key]) != len(clone.callbacks[key]):
return False
for key in self.wrappers:
if len(self.wrappers[key]) != len(clone.wrappers[key]):
return False
if self.injections.keys() != clone.injections.keys():
return False

if len(self.patches) == 0 and len(clone.patches) == 0:
return True
Expand Down Expand Up @@ -700,7 +711,7 @@ def partially_unload(self, device_to, memory_to_free=0):
return memory_freed

def partially_load(self, device_to, extra_memory=0):
with self.use_ejected(skip_injection=True):
with self.use_ejected(skip_and_inject_on_exit_only=True):
self.unpatch_model(unpatch_weights=False)
self.patch_model(load_weights=False)
full_load = False
Expand Down Expand Up @@ -755,8 +766,8 @@ def set_injections(self, key: str, injections: List[PatcherInjection]):
def set_additional_models(self, key: str, models: List['ModelPatcher']):
self.additional_models[key] = models

def use_ejected(self, skip_injection=False):
return AutoPatcherEjector(self, skip_until_exit=skip_injection)
def use_ejected(self, skip_and_inject_on_exit_only=False):
return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only)

def inject_model(self):
if self.is_injected or self.skip_injection:
Expand Down

0 comments on commit 298397d

Please sign in to comment.