Implement GLora.
comfyanonymous committed Dec 9, 2023
1 parent cb63e23 commit 614b7e7
Showing 2 changed files with 21 additions and 0 deletions.
comfy/lora.py: 11 additions, 0 deletions
@@ -118,6 +118,17 @@ def load_lora(lora, to_load):
         if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
             patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
 
+        #glora
+        a1_name = "{}.a1.weight".format(x)
+        a2_name = "{}.a2.weight".format(x)
+        b1_name = "{}.b1.weight".format(x)
+        b2_name = "{}.b2.weight".format(x)
+        if a1_name in lora:
+            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
+            loaded_keys.add(a1_name)
+            loaded_keys.add(a2_name)
+            loaded_keys.add(b1_name)
+            loaded_keys.add(b2_name)
 
         w_norm_name = "{}.w_norm".format(x)
         b_norm_name = "{}.b_norm".format(x)
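For context, a minimal sketch (not part of the commit) of the state-dict layout this branch detects: a GLoRA file carries four low-rank tensors per patched weight, all under one key prefix. The prefix, rank, and feature sizes below are hypothetical.

import torch

# Hypothetical key prefix; real prefixes come from the to_load mapping.
x = "lora_unet_input_blocks_1_1_proj_in"
rank, features = 8, 320  # assumed sizes, chosen to match the mm() shapes used later

lora = {
    "{}.a1.weight".format(x): torch.zeros(rank, features),  # v[0]
    "{}.a2.weight".format(x): torch.zeros(features, rank),  # v[1]
    "{}.b1.weight".format(x): torch.zeros(rank, features),  # v[2]
    "{}.b2.weight".format(x): torch.zeros(features, rank),  # v[3]
}

# The branch above records one patch per matched prefix:
#   patch_dict[to_load[x]] = ("glora", (a1, a2, b1, b2, alpha))
# which calculate_weight() in comfy/model_patcher.py (second file below) consumes.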
comfy/model_patcher.py: 10 additions, 0 deletions
@@ -311,6 +311,16 @@ def calculate_weight(self, patches, weight, key):
                     weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
                 except Exception as e:
                     print("ERROR", key, e)
+            elif patch_type == "glora":
+                if v[4] is not None:
+                    alpha *= v[4] / v[0].shape[0]
+
+                a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
+                a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
+                b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
+                b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
+
+                weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
             else:
                 print("patch type not recognized", patch_type, key)
 
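Read as math, the new branch computes W' = W + scale * (B2 @ B1 + W @ A2 @ A1) with scale = strength * alpha / rank; unlike plain LoRA, the second term depends on the existing weight itself. A minimal sketch of the same arithmetic with made-up shapes (device casting and the float32 upcast are omitted; all sizes and values are assumptions, not from the commit):

import torch

out_f, in_f, rank = 320, 320, 8      # hypothetical layer shape and rank
weight = torch.randn(out_f, in_f)

a1 = torch.randn(rank, in_f)         # v[0]; the rank is read off a1.shape[0]
a2 = torch.randn(in_f, rank)         # v[1]
b1 = torch.randn(rank, in_f)         # v[2]
b2 = torch.randn(out_f, rank)        # v[3]

strength, stored_alpha = 1.0, 8.0    # strength is p[0]; v[4] is the file's stored alpha
scale = strength * stored_alpha / a1.shape[0]

# Low-rank term plus a weight-dependent term, exactly as in the patch above.
delta = (torch.mm(b2, b1) + torch.mm(torch.mm(weight, a2), a1)) * scale
weight = weight + delta
print(weight.shape)                  # torch.Size([320, 320]): shape is unchanged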
