diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 57d7353e1c5..7540f697bb2 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -163,28 +163,27 @@ def new_executor(cls, original: Callable, wrappers: List[Callable]): return cls(original, wrappers, idx=0) class AutoPatcherEjector: - def __init__(self, model: 'ModelPatcher', skip_until_exit=False): + def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False): self.model = model self.was_injected = False self.prev_skip_injection = False - self.skip_until_exit = skip_until_exit + self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only def __enter__(self): self.was_injected = False self.prev_skip_injection = self.model.skip_injection - if self.skip_until_exit: + if self.skip_and_inject_on_exit_only: self.model.skip_injection = True if self.model.is_injected: self.model.eject_model() self.was_injected = True def __exit__(self, *args): - if self.was_injected: - if self.skip_until_exit: - self.model.skip_injection = self.prev_skip_injection - self.model.inject_model() - elif not self.model.skip_injection: - self.model.inject_model() + if self.skip_and_inject_on_exit_only: + self.model.skip_injection = self.prev_skip_injection + self.model.inject_model() + if self.was_injected and not self.model.skip_injection: + self.model.inject_model() self.model.skip_injection = self.prev_skip_injection class PatcherInjection: @@ -319,6 +318,18 @@ def clone_has_same_weights(self, clone: 'ModelPatcher'): return False if self.hook_patches.keys() != clone.hook_patches.keys(): return False + if self.attachments.keys() != clone.attachments.keys(): + return False + if self.additional_models.keys() != clone.additional_models.keys(): + return False + for key in self.callbacks: + if len(self.callbacks[key]) != len(clone.callbacks[key]): + return False + for key in self.wrappers: + if len(self.wrappers[key]) != len(clone.wrappers[key]): + return False + if self.injections.keys() != clone.injections.keys(): + return False if len(self.patches) == 0 and len(clone.patches) == 0: return True @@ -700,7 +711,7 @@ def partially_unload(self, device_to, memory_to_free=0): return memory_freed def partially_load(self, device_to, extra_memory=0): - with self.use_ejected(skip_injection=True): + with self.use_ejected(skip_and_inject_on_exit_only=True): self.unpatch_model(unpatch_weights=False) self.patch_model(load_weights=False) full_load = False @@ -755,8 +766,8 @@ def set_injections(self, key: str, injections: List[PatcherInjection]): def set_additional_models(self, key: str, models: List['ModelPatcher']): self.additional_models[key] = models - def use_ejected(self, skip_injection=False): - return AutoPatcherEjector(self, skip_until_exit=skip_injection) + def use_ejected(self, skip_and_inject_on_exit_only=False): + return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only) def inject_model(self): if self.is_injected or self.skip_injection: