Merge pull request #1482 from kohya-ss/flux-merge-lora
Flux merge lora
kohya-ss committed Aug 20, 2024
2 parents c62c95e + dbed512 commit 388b3b4
Showing 2 changed files with 277 additions and 14 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`

Aug 20, 2024 (update 2):
`flux_merge_lora.py` now supports LoRA models from AI-toolkit (Diffusers-based keys). Specify the `--diffusers` option to merge LoRA models with Diffusers-based keys. Thanks to exveria1015!
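For example, a merge with Diffusers-based keys might be invoked like this (an illustrative command: the file names are placeholders, and `--save_to` is assumed from the script's existing options):
`python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to merged.safetensors --models your_lora.safetensors --ratios 1.0 --diffusers`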

Aug 20, 2024:
FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and may improve results (e.g. 1024x1024, 768x768 and 512x512; any resolution can be used).

288 changes: 274 additions & 14 deletions networks/flux_merge_lora.py
@@ -1,20 +1,23 @@
import math
import argparse
import math
import os
import time

import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, train_util
import networks.lora_flux as lora_flux

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

import lora_flux as lora_flux
from library import sai_model_spec, train_util


def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
@@ -60,7 +63,7 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati
lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU

logger.info(f"merging...")
for key in tqdm(lora_sd.keys()):
for key in tqdm(list(lora_sd.keys())):
if "lora_down" in key:
lora_name = key[: key.rfind(".lora_down")]
up_key = key.replace("lora_down", "lora_up")
@@ -70,11 +73,11 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati
logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.")
continue

down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
down_weight = lora_sd.pop(key)
up_weight = lora_sd.pop(up_key)

dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
alpha = lora_sd.pop(alpha_key, dim)
scale = alpha / dim

# W <- W + U * D
@@ -111,6 +114,253 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati
del down_weight
del weight

if len(lora_sd) > 0:
logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}")

return flux_state_dict


def merge_to_flux_model_diffusers(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype):
logger.info(f"loading keys from FLUX.1 model: {flux_model}")
flux_state_dict = load_file(flux_model, device=loading_device)

def create_key_map(n_double_layers, n_single_layers):
key_map = {}
for index in range(n_double_layers):
prefix_from = f"transformer_blocks.{index}"
prefix_to = f"double_blocks.{index}"

for end in ("weight", "bias"):
k = f"{prefix_from}.attn."
qkv_img = f"{prefix_to}.img_attn.qkv.{end}"
qkv_txt = f"{prefix_to}.txt_attn.qkv.{end}"
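# q, k and v each map to the same fused qkv weight, separately for the image and text streams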

key_map[f"{k}to_q.{end}"] = qkv_img
key_map[f"{k}to_k.{end}"] = qkv_img
key_map[f"{k}to_v.{end}"] = qkv_img
key_map[f"{k}add_q_proj.{end}"] = qkv_txt
key_map[f"{k}add_k_proj.{end}"] = qkv_txt
key_map[f"{k}add_v_proj.{end}"] = qkv_txt

block_map = {
"attn.to_out.0.weight": "img_attn.proj.weight",
"attn.to_out.0.bias": "img_attn.proj.bias",
"norm1.linear.weight": "img_mod.lin.weight",
"norm1.linear.bias": "img_mod.lin.bias",
"norm1_context.linear.weight": "txt_mod.lin.weight",
"norm1_context.linear.bias": "txt_mod.lin.bias",
"attn.to_add_out.weight": "txt_attn.proj.weight",
"attn.to_add_out.bias": "txt_attn.proj.bias",
"ff.net.0.proj.weight": "img_mlp.0.weight",
"ff.net.0.proj.bias": "img_mlp.0.bias",
"ff.net.2.weight": "img_mlp.2.weight",
"ff.net.2.bias": "img_mlp.2.bias",
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
"ff_context.net.2.weight": "txt_mlp.2.weight",
"ff_context.net.2.bias": "txt_mlp.2.bias",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
}

for k, v in block_map.items():
key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}"

for index in range(n_single_layers):
prefix_from = f"single_transformer_blocks.{index}"
prefix_to = f"single_blocks.{index}"

for end in ("weight", "bias"):
k = f"{prefix_from}.attn."
qkv = f"{prefix_to}.linear1.{end}"
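# to_q/to_k/to_v and proj_mlp all collapse into the single fused linear1 weight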
key_map[f"{k}to_q.{end}"] = qkv
key_map[f"{k}to_k.{end}"] = qkv
key_map[f"{k}to_v.{end}"] = qkv
key_map[f"{prefix_from}.proj_mlp.{end}"] = qkv

block_map = {
"norm.linear.weight": "modulation.lin.weight",
"norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
"attn.norm_q.weight": "norm.query_norm.scale",
"attn.norm_k.weight": "norm.key_norm.scale",
}

for k, v in block_map.items():
key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}"

# add as-is keys
values = list([(v if isinstance(v, str) else v[0]) for v in set(key_map.values())])
values.sort()
key_map.update({v: v for v in values})

return key_map

key_map = create_key_map(18, 38) # 18 double layers, 38 single layers

def find_matching_key(flux_dict, lora_key):
lora_key = lora_key.replace("diffusion_model.", "")
lora_key = lora_key.replace("transformer.", "")
lora_key = lora_key.replace("lora_A", "lora_down").replace("lora_B", "lora_up")
lora_key = lora_key.replace("single_transformer_blocks", "single_blocks")
lora_key = lora_key.replace("transformer_blocks", "double_blocks")

double_block_map = {
"attn.to_out.0": "img_attn.proj",
"norm1.linear": "img_mod.lin",
"norm1_context.linear": "txt_mod.lin",
"attn.to_add_out": "txt_attn.proj",
"ff.net.0.proj": "img_mlp.0",
"ff.net.2": "img_mlp.2",
"ff_context.net.0.proj": "txt_mlp.0",
"ff_context.net.2": "txt_mlp.2",
"attn.norm_q": "img_attn.norm.query_norm",
"attn.norm_k": "img_attn.norm.key_norm",
"attn.norm_added_q": "txt_attn.norm.query_norm",
"attn.norm_added_k": "txt_attn.norm.key_norm",
"attn.to_q": "img_attn.qkv",
"attn.to_k": "img_attn.qkv",
"attn.to_v": "img_attn.qkv",
"attn.add_q_proj": "txt_attn.qkv",
"attn.add_k_proj": "txt_attn.qkv",
"attn.add_v_proj": "txt_attn.qkv",
}
single_block_map = {
"norm.linear": "modulation.lin",
"proj_out": "linear2",
"attn.norm_q": "norm.query_norm",
"attn.norm_k": "norm.key_norm",
"attn.to_q": "linear1",
"attn.to_k": "linear1",
"attn.to_v": "linear1",
"proj_mlp": "linear1",
}

# the same key exists in both single_block_map and double_block_map, so we must distinguish single blocks from double blocks
# print("lora_key before double_block_map", lora_key)
for old, new in double_block_map.items():
if "double" in lora_key:
lora_key = lora_key.replace(old, new)
# print("lora_key before single_block_map", lora_key)
for old, new in single_block_map.items():
if "single" in lora_key:
lora_key = lora_key.replace(old, new)
# print("lora_key after mapping", lora_key)

if lora_key in key_map:
flux_key = key_map[lora_key]
logger.info(f"Found matching key: {flux_key}")
return flux_key

# If not found in key_map, try partial matching
potential_key = lora_key + ".weight"
logger.info(f"Searching for key: {potential_key}")
matches = [k for k in flux_dict.keys() if potential_key in k]
if matches:
logger.info(f"Found matching key: {matches[0]}")
return matches[0]
return None

merged_keys = set()
for model, ratio in zip(models, ratios):
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype)

logger.info("merging...")
for key in lora_sd.keys():
if "lora_down" in key or "lora_A" in key:
lora_name = key[: key.rfind(".lora_down" if "lora_down" in key else ".lora_A")]
up_key = key.replace("lora_down", "lora_up").replace("lora_A", "lora_B")
alpha_key = key[: key.index("lora_down" if "lora_down" in key else "lora_A")] + "alpha"

logger.info(f"Processing LoRA key: {lora_name}")
flux_key = find_matching_key(flux_state_dict, lora_name)

if flux_key is None:
logger.warning(f"no module found for LoRA weight: {key}")
continue

logger.info(f"Merging LoRA key {lora_name} into Flux key {flux_key}")

down_weight = lora_sd[key]
up_weight = lora_sd[up_key]

dim = down_weight.size()[0]
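# alpha defaults to the rank (dim), so the effective LoRA scale is alpha / dim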
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim

weight = flux_state_dict[flux_key]

weight = weight.to(working_device, merge_dtype)
up_weight = up_weight.to(working_device, merge_dtype)
down_weight = down_weight.to(working_device, merge_dtype)

# print(up_weight.size(), down_weight.size(), weight.size())

if lora_name.startswith("transformer."):
if "qkv" in flux_key or "linear1" in flux_key: # combined qkv or qkv+mlp
update = ratio * (up_weight @ down_weight) * scale
# print(update.shape)

if "img_attn" in flux_key or "txt_attn" in flux_key:
q, k, v = torch.chunk(weight, 3, dim=0)
if "to_q" in lora_name or "add_q_proj" in lora_name:
q += update.reshape(q.shape)
elif "to_k" in lora_name or "add_k_proj" in lora_name:
k += update.reshape(k.shape)
elif "to_v" in lora_name or "add_v_proj" in lora_name:
v += update.reshape(v.shape)
weight = torch.cat([q, k, v], dim=0)
elif "linear1" in flux_key:
q, k, v = torch.chunk(weight[: int(update.shape[-1] * 3)], 3, dim=0)
mlp = weight[int(update.shape[-1] * 3) :]
# print(q.shape, k.shape, v.shape, mlp.shape)
if "to_q" in lora_name:
q += update.reshape(q.shape)
elif "to_k" in lora_name:
k += update.reshape(k.shape)
elif "to_v" in lora_name:
v += update.reshape(v.shape)
elif "proj_mlp" in lora_name:
mlp += update.reshape(mlp.shape)
weight = torch.cat([q, k, v, mlp], dim=0)
else:
if len(weight.size()) == 2:
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
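# kernel larger than 1x1: compose the down and up convolutions to obtain the full update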
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = weight + ratio * conved * scale
else:
if len(weight.size()) == 2:
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = weight + ratio * conved * scale

flux_state_dict[flux_key] = weight.to(loading_device, save_dtype)
merged_keys.add(flux_key)
del up_weight
del down_weight
del weight

logger.info(f"Merged keys: {sorted(list(merged_keys))}")
return flux_state_dict
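
The fused-qkv handling in merge_to_flux_model_diffusers above is the subtle step: a per-projection LoRA update must be written into one slice of a fused weight. Below is a minimal standalone sketch of that step, with illustrative tensor sizes rather than FLUX.1's real dimensions.

import torch

# toy sizes; FLUX.1 itself uses a hidden size of 3072
hidden, rank = 64, 4
qkv = torch.zeros(3 * hidden, hidden)  # fused [q; k; v] weight, like img_attn.qkv
up_weight = torch.randn(hidden, rank)  # lora_up / lora_B for to_q
down_weight = torch.randn(rank, hidden)  # lora_down / lora_A for to_q
ratio, alpha = 1.0, 4.0
scale = alpha / down_weight.size()[0]  # alpha / rank, as in the script

update = ratio * (up_weight @ down_weight) * scale  # dense update for q only
q, k, v = torch.chunk(qkv, 3, dim=0)  # split the fused weight
q += update.reshape(q.shape)  # apply the update to the q slice alone
qkv = torch.cat([q, k, v], dim=0)  # re-fuse and store back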


@@ -155,7 +405,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")

# merge
logger.info(f"merging...")
logger.info("merging...")
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
continue
@@ -178,7 +428,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
), f"weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。"
), "weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。"
if concat_dim is not None:
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
else:
@@ -226,7 +476,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
def merge(args):
assert len(args.models) == len(
args.ratios
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"

def str_to_dtype(p):
if p == "float":
Expand All @@ -248,9 +498,14 @@ def str_to_dtype(p):
os.makedirs(dest_dir)

if args.flux_model is not None:
state_dict = merge_to_flux_model(
args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype
)
if not args.diffusers:
state_dict = merge_to_flux_model(
args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype
)
else:
state_dict = merge_to_flux_model_diffusers(
args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype
)

if args.no_metadata:
sai_metadata = None
Expand All @@ -267,7 +522,7 @@ def str_to_dtype(p):
else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)

logger.info(f"calculating hashes and creating metadata...")
logger.info("calculating hashes and creating metadata...")

model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
@@ -350,6 +605,11 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="shuffle lora weight./ " + "LoRAの重みをシャッフルする",
)
parser.add_argument(
"--diffusers",
action="store_true",
help="merge Diffusers (?) LoRA models / Diffusers (?) LoRAモデルをマージする",
)

return parser

