From 7df42b9a2364bae6822fbd9e9fa10cea2e319ba3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 23 Aug 2024 04:58:59 -0400 Subject: [PATCH] Fix dora. --- comfy/lora.py | 20 ++++++++++++++++++++ comfy/model_patcher.py | 21 --------------------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/comfy/lora.py b/comfy/lora.py index 9d8a7908aab..a3e7d9cc0c4 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -327,6 +327,26 @@ def model_lora_keys_unet(model, key_map={}): return key_map +def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype): + dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) + lora_diff *= alpha + weight_calc = weight + lora_diff.type(weight.dtype) + weight_norm = ( + weight_calc.transpose(0, 1) + .reshape(weight_calc.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) + .transpose(0, 1) + ) + + weight_calc *= (dora_scale / weight_norm).type(weight.dtype) + if strength != 1.0: + weight_calc -= weight + weight += strength * (weight_calc) + else: + weight[:] = weight_calc + return weight + def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): for p in patches: strength = p[0] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 1d83ba7c226..1f8100698ad 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -31,27 +31,6 @@ from comfy.types import UnetWrapperFunction -def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype): - dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) - lora_diff *= alpha - weight_calc = weight + lora_diff.type(weight.dtype) - weight_norm = ( - weight_calc.transpose(0, 1) - .reshape(weight_calc.shape[1], -1) - .norm(dim=1, keepdim=True) - .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) - .transpose(0, 1) - ) - - weight_calc *= (dora_scale / weight_norm).type(weight.dtype) - if strength != 1.0: - weight_calc -= weight - weight += strength * (weight_calc) - else: - weight[:] = weight_calc - return weight - - def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy()