Skip to content

Commit

Permalink
Fix dora.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Aug 23, 2024
1 parent 5d8bbb7 commit 7df42b9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
20 changes: 20 additions & 0 deletions comfy/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,26 @@ def model_lora_keys_unet(model, key_map={}):
return key_map


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)

weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight

def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
for p in patches:
strength = p[0]
Expand Down
21 changes: 0 additions & 21 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,6 @@
from comfy.types import UnetWrapperFunction


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)

weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight


def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()

Expand Down

0 comments on commit 7df42b9

Please sign in to comment.