Skip to content

Commit

Permalink
Make lora code a bit cleaner.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Dec 9, 2023
1 parent 9e41107 commit cb63e23
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
14 changes: 7 additions & 7 deletions comfy/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
loaded_keys.add(A_name)
loaded_keys.add(B_name)

Expand All @@ -64,7 +64,7 @@ def load_lora(lora, to_load):
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)

patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
Expand Down Expand Up @@ -116,7 +116,7 @@ def load_lora(lora, to_load):
loaded_keys.add(lokr_t2_name)

if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))


w_norm_name = "{}.w_norm".format(x)
Expand All @@ -126,21 +126,21 @@ def load_lora(lora, to_load):

if w_norm is not None:
loaded_keys.add(w_norm_name)
patch_dict[to_load[x]] = (w_norm,)
patch_dict[to_load[x]] = ("diff", (w_norm,))
if b_norm is not None:
loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))

diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None)
if diff_weight is not None:
patch_dict[to_load[x]] = (diff_weight,)
patch_dict[to_load[x]] = ("diff", (diff_weight,))
loaded_keys.add(diff_name)

diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name)

for x in lora.keys():
Expand Down
14 changes: 11 additions & 3 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,19 @@ def calculate_weight(self, patches, weight, key):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )

if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]

if patch_type == "diff":
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
elif len(v) == 4: #lora/locon
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
if v[2] is not None:
Expand All @@ -237,7 +243,7 @@ def calculate_weight(self, patches, weight, key):
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif len(v) == 8: #lokr
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
Expand Down Expand Up @@ -276,7 +282,7 @@ def calculate_weight(self, patches, weight, key):
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
else: #loha
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
Expand Down Expand Up @@ -305,6 +311,8 @@ def calculate_weight(self, patches, weight, key):
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
else:
print("patch type not recognized", patch_type, key)

return weight

Expand Down

0 comments on commit cb63e23

Please sign in to comment.