From 614b7e731f7f9fdcf11eeb46e0623b0977a7e634 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 9 Dec 2023 18:15:26 -0500
Subject: [PATCH] Implement GLora.

---
 comfy/lora.py          | 11 +++++++++++
 comfy/model_patcher.py | 10 ++++++++++
 2 files changed, 21 insertions(+)

diff --git a/comfy/lora.py b/comfy/lora.py
index ecd518084a5..5e4009b47f9 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -118,6 +118,17 @@ def load_lora(lora, to_load):
         if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
             patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
 
+        #glora
+        a1_name = "{}.a1.weight".format(x)
+        a2_name = "{}.a2.weight".format(x)
+        b1_name = "{}.b1.weight".format(x)
+        b2_name = "{}.b2.weight".format(x)
+        if a1_name in lora:
+            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
+            loaded_keys.add(a1_name)
+            loaded_keys.add(a2_name)
+            loaded_keys.add(b1_name)
+            loaded_keys.add(b2_name)
 
         w_norm_name = "{}.w_norm".format(x)
         b_norm_name = "{}.b_norm".format(x)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index d78cdfd4dfd..55ca913ec78 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -311,6 +311,16 @@ def calculate_weight(self, patches, weight, key):
                     weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
                 except Exception as e:
                     print("ERROR", key, e)
+            elif patch_type == "glora":
+                if v[4] is not None:
+                    alpha *= v[4] / v[0].shape[0]
+
+                a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
+                a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
+                b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
+                b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
+
+                weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
             else:
                 print("patch type not recognized", patch_type, key)
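
Note (not part of the patch): the "glora" branch added to calculate_weight merges the patch as W' = W + (B2 @ B1 + (W @ A2) @ A1) * scale, where W is the flattened base weight. Below is a minimal standalone sketch of that merge on a plain 2D tensor; the function name, tensor sizes, and the simplified scale (alpha divided by rank) are illustrative only and do not appear in the patch.

import torch

def glora_merge(weight, a1, a2, b1, b2, alpha=1.0):
    # Illustrative shapes: weight (out, in), a1 (rank, in), a2 (in, rank),
    # b1 (rank, in), b2 (out, rank).
    scale = alpha / a1.shape[0]  # fold alpha over the rank, as the patch does when v[4] is not None
    delta = torch.mm(b2, b1) + torch.mm(torch.mm(weight, a2), a1)
    return weight + delta * scale

# Toy usage with hypothetical sizes: out=8, in=4, rank=2
w = torch.randn(8, 4)
a1, a2 = torch.randn(2, 4), torch.randn(4, 2)
b1, b2 = torch.randn(2, 4), torch.randn(8, 2)
print(glora_merge(w, a1, a2, b1, b2, alpha=2.0).shape)  # torch.Size([8, 4])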