diff --git a/assets/style.css b/assets/style.css index 41c2c438..f8cfe112 100644 --- a/assets/style.css +++ b/assets/style.css @@ -219,3 +219,28 @@ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ } +.flux1_rank_layers_background { + background: #ece9e6; /* White background for clear theme */ + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; +} + +.flux1_rank_layers_background:hover { + background-color: #dddad7; /* Slightly darker on hover */ + border: 1px solid #ccc; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} + +.dark .flux1_rank_layers_background { + background: #131c25; /* Dark background for dark theme */ + padding: 1em; + border-radius: 8px; + transition: background-color 0.3s ease, border 0.3s ease, box-shadow 0.3s ease; +} + +.dark .flux1_rank_layers_background:hover { + background-color: #131c25; /* Slightly darker on hover */ + border: 1px solid #000000; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow on hover */ +} \ No newline at end of file diff --git a/kohya_gui/class_flux1.py b/kohya_gui/class_flux1.py index da517d4f..a4d207f5 100644 --- a/kohya_gui/class_flux1.py +++ b/kohya_gui/class_flux1.py @@ -52,7 +52,7 @@ def noise_offset_type_change( outputs=self.ae, show_progress=False, ) - + self.clip_l = gr.Textbox( label="CLIP-L Path", placeholder="Path to CLIP-L model", @@ -90,20 +90,22 @@ def noise_offset_type_change( ) with gr.Row(): - + self.discrete_flow_shift = gr.Number( label="Discrete Flow Shift", value=self.config.get("flux1.discrete_flow_shift", 3.0), info="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0", minimum=-1024, maximum=1024, - step=.01, + step=0.01, interactive=True, ) self.model_prediction_type = gr.Dropdown( label="Model Prediction Type", choices=["raw", "additive", "sigma_scaled"], - value=self.config.get("flux1.timestep_sampling", "sigma_scaled"), + value=self.config.get( + 
"flux1.timestep_sampling", "sigma_scaled" + ), interactive=True, ) self.timestep_sampling = gr.Dropdown( @@ -156,10 +158,10 @@ def noise_offset_type_change( info="Guidance scale for Flux1", minimum=0, maximum=1024, - step=.1, + step=0.1, interactive=True, ) - self.t5xxl_max_token_length = gr.Number( + self.t5xxl_max_token_length = gr.Number( label="T5-XXL Max Token Length", value=self.config.get("flux1.t5xxl_max_token_length", 512), info="Max token length for T5-XXL", @@ -168,11 +170,19 @@ def noise_offset_type_change( step=1, interactive=True, ) - + self.enable_all_linear = gr.Checkbox( + label="Enable All Linear", + value=self.config.get("flux1.enable_all_linear", False), + info="(Only applicable to 'Flux1 OFT' LoRA) Target all linear connections in the MLP layer. The default is False, which targets only attention.", + interactive=True, + ) + with gr.Row(): self.flux1_cache_text_encoder_outputs = gr.Checkbox( label="Cache Text Encoder Outputs", - value=self.config.get("flux1.cache_text_encoder_outputs", False), + value=self.config.get( + "flux1.cache_text_encoder_outputs", False + ), info="Cache text encoder outputs to speed up inference", interactive=True, ) @@ -190,11 +200,13 @@ def noise_offset_type_change( info="[Experimentsl] Enable memory efficient save. We do not recommend using it unless you are familiar with the code.", interactive=True, ) - + with gr.Row(visible=True if finetuning else False): self.blockwise_fused_optimizers = gr.Checkbox( label="Blockwise Fused Optimizer", - value=self.config.get("flux1.blockwise_fused_optimizers", False), + value=self.config.get( + "flux1.blockwise_fused_optimizers", False + ), info="Enable blockwise optimizers for fused backward pass and optimizer step. Any optimizer can be used.", interactive=True, ) @@ -228,6 +240,62 @@ def noise_offset_type_change( info="Enables the fusing of the optimizer step into the backward pass for each parameter. 
Only Adafactor optimizer is supported.", interactive=True, ) + with gr.Accordion( + "Rank for layers", + open=False, + visible=False if finetuning else True, + elem_classes=["flux1_rank_layers_background"], + ): + with gr.Row(): + self.img_attn_dim = gr.Textbox( + label="img_attn_dim", + value=self.config.get("flux1.img_attn_dim", ""), + interactive=True, + ) + self.img_mlp_dim = gr.Textbox( + label="img_mlp_dim", + value=self.config.get("flux1.img_mlp_dim", ""), + interactive=True, + ) + self.img_mod_dim = gr.Textbox( + label="img_mod_dim", + value=self.config.get("flux1.img_mod_dim", ""), + interactive=True, + ) + self.single_dim = gr.Textbox( + label="single_dim", + value=self.config.get("flux1.single_dim", ""), + interactive=True, + ) + with gr.Row(): + self.txt_attn_dim = gr.Textbox( + label="txt_attn_dim", + value=self.config.get("flux1.txt_attn_dim", ""), + interactive=True, + ) + self.txt_mlp_dim = gr.Textbox( + label="txt_mlp_dim", + value=self.config.get("flux1.txt_mlp_dim", ""), + interactive=True, + ) + self.txt_mod_dim = gr.Textbox( + label="txt_mod_dim", + value=self.config.get("flux1.txt_mod_dim", ""), + interactive=True, + ) + self.single_mod_dim = gr.Textbox( + label="single_mod_dim", + value=self.config.get("flux1.single_mod_dim", ""), + interactive=True, + ) + with gr.Row(): + self.in_dims = gr.Textbox( + label="in_dims", + value=self.config.get("flux1.in_dims", ""), + placeholder="e.g., [4,2,2,2,4]", + info="Each number corresponds to img_in, time_in, vector_in, guidance_in, txt_in. 
The above example applies LoRA to all conditioning layers, with rank 4 for img_in, 2 for time_in, vector_in, guidance_in, and 4 for txt_in.", + interactive=True, + ) self.flux1_checkbox.change( lambda flux1_checkbox: gr.Accordion(visible=flux1_checkbox), diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py index bf288a82..f9706313 100644 --- a/kohya_gui/lora_gui.py +++ b/kohya_gui/lora_gui.py @@ -277,12 +277,22 @@ def save_configuration( split_mode, train_blocks, t5xxl_max_token_length, + enable_all_linear, guidance_scale, mem_eff_save, apply_t5_attn_mask, split_qkv, train_t5xxl, cpu_offload_checkpointing, + img_attn_dim, + img_mlp_dim, + img_mod_dim, + single_dim, + txt_attn_dim, + txt_mlp_dim, + txt_mod_dim, + single_mod_dim, + in_dims, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -528,12 +538,23 @@ def open_configuration( split_mode, train_blocks, t5xxl_max_token_length, + enable_all_linear, guidance_scale, mem_eff_save, apply_t5_attn_mask, split_qkv, train_t5xxl, cpu_offload_checkpointing, + img_attn_dim, + img_mlp_dim, + img_mod_dim, + single_dim, + txt_attn_dim, + txt_mlp_dim, + txt_mod_dim, + single_mod_dim, + in_dims, + ## training_preset, ): @@ -812,12 +833,22 @@ def train_model( split_mode, train_blocks, t5xxl_max_token_length, + enable_all_linear, guidance_scale, mem_eff_save, apply_t5_attn_mask, split_qkv, train_t5xxl, cpu_offload_checkpointing, + img_attn_dim, + img_mlp_dim, + img_mod_dim, + single_dim, + txt_attn_dim, + txt_mlp_dim, + txt_mod_dim, + single_mod_dim, + in_dims, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -845,7 +876,6 @@ def train_model( if flux1_checkbox: log.info(f"Validating lora type is Flux1 if flux1 checkbox is checked...") - print(LoRA_type) if (LoRA_type != "Flux1") and (LoRA_type != "Flux1 OFT"): log.error("LoRA type must be set to 'Flux1' or 'Flux1 OFT' if Flux1 checkbox is checked.") return TRAIN_BUTTON_VISIBLE @@ -1135,7 +1165,17 @@ 
def train_model( if LoRA_type == "Flux1": # Add a list of supported network arguments for Flux1 below when supported - kohya_lora_var_list = [] + kohya_lora_var_list = [ + "img_attn_dim", + "img_mlp_dim", + "img_mod_dim", + "single_dim", + "txt_attn_dim", + "txt_mlp_dim", + "txt_mod_dim", + "single_mod_dim", + "in_dims", + ] network_module = "networks.lora_flux" kohya_lora_vars = { key: value @@ -1160,7 +1200,9 @@ def train_model( if LoRA_type == "Flux1 OFT": # Add a list of supported network arguments for Flux1 OFT below when supported - kohya_lora_var_list = [] + kohya_lora_var_list = [ + "enable_all_linear", + ] network_module = "networks.oft_flux" kohya_lora_vars = { key: value @@ -1602,12 +1644,12 @@ def lora_tab( config=config, ) + with gr.Accordion("Folders", open=True), gr.Group(): + folders = Folders(headless=headless, config=config) + with gr.Accordion("Metadata", open=False), gr.Group(): metadata = MetaData(config=config) - with gr.Accordion("Folders", open=False), gr.Group(): - folders = Folders(headless=headless, config=config) - with gr.Accordion("Dataset Preparation", open=False): gr.Markdown( "This section provide Dreambooth tools to help setup your dataset..." @@ -2675,12 +2717,22 @@ def update_LoRA_settings( flux1_training.split_mode, flux1_training.train_blocks, flux1_training.t5xxl_max_token_length, + flux1_training.enable_all_linear, flux1_training.guidance_scale, flux1_training.mem_eff_save, flux1_training.apply_t5_attn_mask, flux1_training.split_qkv, flux1_training.train_t5xxl, flux1_training.cpu_offload_checkpointing, + flux1_training.img_attn_dim, + flux1_training.img_mlp_dim, + flux1_training.img_mod_dim, + flux1_training.single_dim, + flux1_training.txt_attn_dim, + flux1_training.txt_mlp_dim, + flux1_training.txt_mod_dim, + flux1_training.single_mod_dim, + flux1_training.in_dims, ] configuration.button_open_config.click(