Skip to content

Commit

Permalink
Merge PR #440 from Kosinkadink/current_device_fix
Browse files Browse the repository at this point in the history
Remove current_device param from ModelPatcher init
  • Loading branch information
Kosinkadink committed Aug 6, 2024
2 parents 106d691 + c1c3bbc commit f297a20
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
8 changes: 4 additions & 4 deletions animatediff/model_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
class ModelPatcherAndInjector(ModelPatcher):
def __init__(self, m: ModelPatcher):
# replicate ModelPatcher.clone() to initialize ModelPatcherAndInjector
super().__init__(m.model, m.load_device, m.offload_device, m.size, m.current_device, weight_inplace_update=m.weight_inplace_update)
super().__init__(m.model, m.load_device, m.offload_device, m.size, weight_inplace_update=m.weight_inplace_update)
self.patches = {}
for k in m.patches:
self.patches[k] = m.patches[k][:]
Expand Down Expand Up @@ -439,7 +439,7 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches, strength_pat
class ModelPatcherCLIPHooks(ModelPatcher):
def __init__(self, m: ModelPatcher):
# replicate ModelPatcher.clone() to initialize
super().__init__(m.model, m.load_device, m.offload_device, m.size, m.current_device, weight_inplace_update=m.weight_inplace_update)
super().__init__(m.model, m.load_device, m.offload_device, m.size, weight_inplace_update=m.weight_inplace_update)
self.patches = {}
for k in m.patches:
self.patches[k] = m.patches[k][:]
Expand Down Expand Up @@ -1016,7 +1016,7 @@ def cleanup(self):

def clone(self):
# normal ModelPatcher clone actions
n = MotionModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
n = MotionModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
Expand Down Expand Up @@ -1124,7 +1124,7 @@ def get_name_string(self, show_version=False):


def get_vanilla_model_patcher(m: ModelPatcher) -> ModelPatcher:
model = ModelPatcher(m.model, m.load_device, m.offload_device, m.size, m.current_device, weight_inplace_update=m.weight_inplace_update)
model = ModelPatcher(m.model, m.load_device, m.offload_device, m.size, weight_inplace_update=m.weight_inplace_update)
model.patches = {}
for k in m.patches:
model.patches[k] = m.patches[k][:]
Expand Down
7 changes: 6 additions & 1 deletion animatediff/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,9 +487,14 @@ def ad_callback(step, x0, x, total_steps):
iter_model = model.model
else:
iter_model = model
current_device = None
if hasattr(model, "current_device"): # backwards compatibility, for now
current_device = model.current_device
else:
current_device = model.model.device
iter_kwargs[IterationOptions.SAMPLER] = comfy.samplers.KSampler(
iter_model, steps=999, #steps=args[-7],
device=model.current_device, sampler=args[-5],
device=current_device, sampler=args[-5],
scheduler=args[-4], denoise=kwargs.get("denoise", None),
model_options=model.model_options)
del iter_model
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui-animatediff-evolved"
description = "Improved AnimateDiff integration for ComfyUI."
version = "1.0.11"
version = "1.0.12"
license = { file = "LICENSE" }
dependencies = []

Expand Down

0 comments on commit f297a20

Please sign in to comment.