diff --git a/adv_control/control_reference.py b/adv_control/control_reference.py index 4b3c45b..815d651 100644 --- a/adv_control/control_reference.py +++ b/adv_control/control_reference.py @@ -62,7 +62,10 @@ def refcn_sample(model: ModelPatcher, *args, **kwargs): for i, module in enumerate(attn_modules): injection_holder = InjectionBasicTransformerBlockHolder(block=module, idx=i) injection_holder.attn_weight = float(i) / float(len(attn_modules)) - module._forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module)) + if hasattr(module, "_forward"): # backward compatibility + module._forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module)) + else: + module.forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module)) module.injection_holder = injection_holder reference_injections.attn_modules.append(module) # figure out which module is middle block @@ -430,14 +433,20 @@ def clean(self): class InjectionBasicTransformerBlockHolder: def __init__(self, block: BasicTransformerBlock, idx=None): - self.original_forward = block._forward + if hasattr(block, "_forward"): # backward compatibility + self.original_forward = block._forward + else: + self.original_forward = block.forward self.idx = idx self.attn_weight = 1.0 self.is_middle = False self.bank_styles = BankStylesBasicTransformerBlock() def restore(self, block: BasicTransformerBlock): - block._forward = self.original_forward + if hasattr(block, "_forward"): # backward compatibility + block._forward = self.original_forward + else: + block.forward = self.original_forward def clean(self): self.bank_styles.clean() diff --git a/pyproject.toml b/pyproject.toml index efffd8f..9df3e69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-advanced-controlnet" description = "Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks." -version = "1.0.1" +version = "1.0.2" license = "LICENSE" dependencies = []