Add support for Rank for layers
bmaltais committed Sep 14, 2024
1 parent 6c5c9d4 commit d24fae1
Showing 3 changed files with 161 additions and 16 deletions.
25 changes: 25 additions & 0 deletions assets/style.css
@@ -219,3 +219,28 @@
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}

.flux1_rank_layers_background {
background: #ece9e6; /* Light neutral background for the light theme */
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
}

.flux1_rank_layers_background:hover {
background-color: #dddad7; /* Slightly darker on hover */
border: 1px solid #ccc;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}

.dark .flux1_rank_layers_background {
background: #131c25; /* Dark background for dark theme */
padding: 1em;
border-radius: 8px;
transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease;
}

.dark .flux1_rank_layers_background:hover {
background-color: #131c25; /* Same dark background on hover; the border and shadow provide the hover cue */
border: 1px solid #000000;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */
}
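
For reference, the new class only takes effect once a component opts into it through Gradio's elem_classes, which is what the class_flux1.py diff below does for the "Rank for layers" accordion. A minimal standalone sketch of that hookup (not part of the commit; the component names are illustrative):

import gradio as gr

# Load the stylesheet that defines .flux1_rank_layers_background and pass it to Blocks.
with open("assets/style.css") as f:
    css = f.read()

with gr.Blocks(css=css) as demo:
    # elem_classes attaches the CSS class, so the light/dark rules above apply.
    with gr.Accordion(
        "Rank for layers",
        open=False,
        elem_classes=["flux1_rank_layers_background"],
    ):
        gr.Textbox(label="img_attn_dim")

if __name__ == "__main__":
    demo.launch()
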
88 changes: 78 additions & 10 deletions kohya_gui/class_flux1.py
@@ -52,7 +52,7 @@ def noise_offset_type_change(
outputs=self.ae,
show_progress=False,
)

self.clip_l = gr.Textbox(
label="CLIP-L Path",
placeholder="Path to CLIP-L model",
@@ -90,20 +90,22 @@ def noise_offset_type_change(
)

with gr.Row():

self.discrete_flow_shift = gr.Number(
label="Discrete Flow Shift",
value=self.config.get("flux1.discrete_flow_shift", 3.0),
info="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0",
minimum=-1024,
maximum=1024,
step=.01,
step=0.01,
interactive=True,
)
self.model_prediction_type = gr.Dropdown(
label="Model Prediction Type",
choices=["raw", "additive", "sigma_scaled"],
value=self.config.get("flux1.timestep_sampling", "sigma_scaled"),
value=self.config.get(
"flux1.model_prediction_type", "sigma_scaled"
),
interactive=True,
)
self.timestep_sampling = gr.Dropdown(
@@ -156,10 +158,10 @@ def noise_offset_type_change(
info="Guidance scale for Flux1",
minimum=0,
maximum=1024,
step=.1,
step=0.1,
interactive=True,
)
self.t5xxl_max_token_length = gr.Number(
self.t5xxl_max_token_length = gr.Number(
label="T5-XXL Max Token Length",
value=self.config.get("flux1.t5xxl_max_token_length", 512),
info="Max token length for T5-XXL",
@@ -168,11 +170,19 @@
step=1,
interactive=True,
)

self.enable_all_linear = gr.Checkbox(
label="Enable All Linear",
value=self.config.get("flux1.enable_all_linear", False),
info="(Only applicable to 'FLux1 OFT' LoRA) Target all linear connections in the MLP layer. The default is False, which targets only attention.",
interactive=True,
)

with gr.Row():
self.flux1_cache_text_encoder_outputs = gr.Checkbox(
label="Cache Text Encoder Outputs",
value=self.config.get("flux1.cache_text_encoder_outputs", False),
value=self.config.get(
"flux1.cache_text_encoder_outputs", False
),
info="Cache text encoder outputs to speed up inference",
interactive=True,
)
@@ -190,11 +200,13 @@
info="[Experimentsl] Enable memory efficient save. We do not recommend using it unless you are familiar with the code.",
interactive=True,
)

with gr.Row(visible=True if finetuning else False):
self.blockwise_fused_optimizers = gr.Checkbox(
label="Blockwise Fused Optimizer",
value=self.config.get("flux1.blockwise_fused_optimizers", False),
value=self.config.get(
"flux1.blockwise_fused_optimizers", False
),
info="Enable blockwise optimizers for fused backward pass and optimizer step. Any optimizer can be used.",
interactive=True,
)
@@ -228,6 +240,62 @@ def noise_offset_type_change(
info="Enables the fusing of the optimizer step into the backward pass for each parameter. Only Adafactor optimizer is supported.",
interactive=True,
)
with gr.Accordion(
"Rank for layers",
open=False,
visible=False if finetuning else True,
elem_classes=["flux1_rank_layers_background"],
):
with gr.Row():
self.img_attn_dim = gr.Textbox(
label="img_attn_dim",
value=self.config.get("flux1.img_attn_dim", ""),
interactive=True,
)
self.img_mlp_dim = gr.Textbox(
label="img_mlp_dim",
value=self.config.get("flux1.img_mlp_dim", ""),
interactive=True,
)
self.img_mod_dim = gr.Textbox(
label="img_mod_dim",
value=self.config.get("flux1.img_mod_dim", ""),
interactive=True,
)
self.single_dim = gr.Textbox(
label="single_dim",
value=self.config.get("flux1.single_dim", ""),
interactive=True,
)
with gr.Row():
self.txt_attn_dim = gr.Textbox(
label="txt_attn_dim",
value=self.config.get("flux1.txt_attn_dim", ""),
interactive=True,
)
self.txt_mlp_dim = gr.Textbox(
label="txt_mlp_dim",
value=self.config.get("flux1.txt_mlp_dim", ""),
interactive=True,
)
self.txt_mod_dim = gr.Textbox(
label="txt_mod_dim",
value=self.config.get("flux1.txt_mod_dim", ""),
interactive=True,
)
self.single_mod_dim = gr.Textbox(
label="single_mod_dim",
value=self.config.get("flux1.single_mod_dim", ""),
interactive=True,
)
with gr.Row():
self.in_dims = gr.Textbox(
label="in_dims",
value=self.config.get("flux1.in_dims", ""),
placeholder="e.g., [4,0,0,0,4]",
info="Each number corresponds to img_in, time_in, vector_in, guidance_in, txt_in. The above example applies LoRA to all conditioning layers, with rank 4 for img_in, 2 for time_in, vector_in, guidance_in, and 4 for txt_in.",
interactive=True,
)

self.flux1_checkbox.change(
lambda flux1_checkbox: gr.Accordion(visible=flux1_checkbox),
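
The textboxes added above collect per-layer LoRA ranks as plain strings; in sd-scripts these names (img_attn_dim, txt_attn_dim, ..., in_dims) are network_args understood by networks.lora_flux. A hedged sketch of how non-empty values could be turned into network arguments (the actual assembly happens in lora_gui.py's train_model, shown next, and may differ in detail):

# Hedged sketch, not the commit's code: forward only the ranks the user filled in.
rank_overrides = {
    "img_attn_dim": "4",
    "txt_attn_dim": "4",
    "img_mlp_dim": "",      # left empty in the UI -> not forwarded
    "in_dims": "[4,2,2,2,4]",
}
network_args = [f"{key}={value}" for key, value in rank_overrides.items() if value]
print(network_args)  # ['img_attn_dim=4', 'txt_attn_dim=4', 'in_dims=[4,2,2,2,4]']
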
64 changes: 58 additions & 6 deletions kohya_gui/lora_gui.py
@@ -277,12 +277,22 @@ def save_configuration(
split_mode,
train_blocks,
t5xxl_max_token_length,
enable_all_linear,
guidance_scale,
mem_eff_save,
apply_t5_attn_mask,
split_qkv,
train_t5xxl,
cpu_offload_checkpointing,
img_attn_dim,
img_mlp_dim,
img_mod_dim,
single_dim,
txt_attn_dim,
txt_mlp_dim,
txt_mod_dim,
single_mod_dim,
in_dims,
):
# Get list of function parameters and values
parameters = list(locals().items())
@@ -528,12 +538,23 @@ def open_configuration(
split_mode,
train_blocks,
t5xxl_max_token_length,
enable_all_linear,
guidance_scale,
mem_eff_save,
apply_t5_attn_mask,
split_qkv,
train_t5xxl,
cpu_offload_checkpointing,
img_attn_dim,
img_mlp_dim,
img_mod_dim,
single_dim,
txt_attn_dim,
txt_mlp_dim,
txt_mod_dim,
single_mod_dim,
in_dims,

##
training_preset,
):
@@ -812,12 +833,22 @@ def train_model(
split_mode,
train_blocks,
t5xxl_max_token_length,
enable_all_linear,
guidance_scale,
mem_eff_save,
apply_t5_attn_mask,
split_qkv,
train_t5xxl,
cpu_offload_checkpointing,
img_attn_dim,
img_mlp_dim,
img_mod_dim,
single_dim,
txt_attn_dim,
txt_mlp_dim,
txt_mod_dim,
single_mod_dim,
in_dims,
):
# Get list of function parameters and values
parameters = list(locals().items())
@@ -845,7 +876,6 @@ def train_model(

if flux1_checkbox:
log.info(f"Validating lora type is Flux1 if flux1 checkbox is checked...")
print(LoRA_type)
if (LoRA_type != "Flux1") and (LoRA_type != "Flux1 OFT"):
log.error("LoRA type must be set to 'Flux1' or 'Flux1 OFT' if Flux1 checkbox is checked.")
return TRAIN_BUTTON_VISIBLE
@@ -1135,7 +1165,17 @@ def train_model(

if LoRA_type == "Flux1":
# Add a list of supported network arguments for Flux1 below when supported
kohya_lora_var_list = []
kohya_lora_var_list = [
"img_attn_dim",
"img_mlp_dim",
"img_mod_dim",
"single_dim",
"txt_attn_dim",
"txt_mlp_dim",
"txt_mod_dim",
"single_mod_dim",
"in_dims",
]
network_module = "networks.lora_flux"
kohya_lora_vars = {
key: value
Expand All @@ -1160,7 +1200,9 @@ def train_model(

if LoRA_type == "Flux1 OFT":
# Add a list of supported network arguments for Flux1 OFT below when supported
kohya_lora_var_list = []
kohya_lora_var_list = [
"enable_all_linear",
]
network_module = "networks.oft_flux"
kohya_lora_vars = {
key: value
@@ -1602,12 +1644,12 @@ def lora_tab(
config=config,
)

with gr.Accordion("Folders", open=True), gr.Group():
folders = Folders(headless=headless, config=config)

with gr.Accordion("Metadata", open=False), gr.Group():
metadata = MetaData(config=config)

with gr.Accordion("Folders", open=False), gr.Group():
folders = Folders(headless=headless, config=config)

with gr.Accordion("Dataset Preparation", open=False):
gr.Markdown(
"This section provide Dreambooth tools to help setup your dataset..."
@@ -2675,12 +2717,22 @@ def update_LoRA_settings(
flux1_training.split_mode,
flux1_training.train_blocks,
flux1_training.t5xxl_max_token_length,
flux1_training.enable_all_linear,
flux1_training.guidance_scale,
flux1_training.mem_eff_save,
flux1_training.apply_t5_attn_mask,
flux1_training.split_qkv,
flux1_training.train_t5xxl,
flux1_training.cpu_offload_checkpointing,
flux1_training.img_attn_dim,
flux1_training.img_mlp_dim,
flux1_training.img_mod_dim,
flux1_training.single_dim,
flux1_training.txt_attn_dim,
flux1_training.txt_mlp_dim,
flux1_training.txt_mod_dim,
flux1_training.single_mod_dim,
flux1_training.in_dims,
]

configuration.button_open_config.click(
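
For readers unfamiliar with the whitelist pattern above: train_model captures its parameters via locals() and uses per-LoRA-type lists such as kohya_lora_var_list (now including the rank fields for "Flux1" and enable_all_linear for "Flux1 OFT") to decide which values become network arguments for the selected network_module. A simplified, illustrative sketch of that filtering (the commit's actual comprehension is truncated in the diff above):

# Illustrative sketch of the whitelist filtering; the real code works on
# train_model's own locals() and many more parameters.
def collect_network_vars(parameters, allowed):
    # Keep only whitelisted, non-empty values so they can be emitted as
    # network_args for the chosen network_module (e.g. networks.lora_flux).
    return {key: value for key, value in parameters if key in allowed and value}

params = [("img_attn_dim", "4"), ("in_dims", ""), ("learning_rate", 1e-4)]
allowed = ["img_attn_dim", "img_mlp_dim", "in_dims"]
print(collect_network_vars(params, allowed))  # {'img_attn_dim': '4'}
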
