Skip to content

Commit

Permalink
bring back SDXLConfig accordion for dreambooth gui
Browse files Browse the repository at this point in the history
  • Loading branch information
b-fission committed Aug 6, 2024
1 parent df0c81d commit d96df32
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions kohya_gui/dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .class_command_executor import CommandExecutor
from .class_huggingface import HuggingFace
from .class_metadata import MetaData
from .class_sdxl_parameters import SDXLParameters

from .dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
Expand Down Expand Up @@ -162,6 +163,8 @@ def save_configuration(
log_tracker_name,
log_tracker_config,
scale_v_pred_loss_like_noise_pred,
sdxl_cache_text_encoder_outputs,
sdxl_no_half_vae,
min_timestep,
max_timestep,
debiased_estimation_loss,
Expand Down Expand Up @@ -320,6 +323,8 @@ def open_configuration(
log_tracker_name,
log_tracker_config,
scale_v_pred_loss_like_noise_pred,
sdxl_cache_text_encoder_outputs,
sdxl_no_half_vae,
min_timestep,
max_timestep,
debiased_estimation_loss,
Expand Down Expand Up @@ -473,6 +478,8 @@ def train_model(
log_tracker_name,
log_tracker_config,
scale_v_pred_loss_like_noise_pred,
sdxl_cache_text_encoder_outputs,
sdxl_no_half_vae,
min_timestep,
max_timestep,
debiased_estimation_loss,
Expand Down Expand Up @@ -705,6 +712,9 @@ def train_model(
else:
run_cmd.append(rf"{scriptdir}/sd-scripts/train_db.py")

cache_text_encoder_outputs = sdxl and sdxl_cache_text_encoder_outputs
no_half_vae = sdxl and sdxl_no_half_vae

if max_data_loader_n_workers == "" or None:
max_data_loader_n_workers = 0
else:
Expand All @@ -724,6 +734,7 @@ def train_model(
"bucket_reso_steps": bucket_reso_steps,
"cache_latents": cache_latents,
"cache_latents_to_disk": cache_latents_to_disk,
"cache_text_encoder_outputs": cache_text_encoder_outputs,
"caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
"caption_dropout_rate": caption_dropout_rate,
"caption_extension": caption_extension,
Expand Down Expand Up @@ -789,6 +800,7 @@ def train_model(
"mixed_precision": mixed_precision,
"multires_noise_discount": multires_noise_discount,
"multires_noise_iterations": multires_noise_iterations if not 0 else None,
"no_half_vae": no_half_vae,
"no_token_padding": no_token_padding,
"noise_offset": noise_offset if not 0 else None,
"noise_offset_random_strength": noise_offset_random_strength,
Expand Down Expand Up @@ -981,6 +993,11 @@ def dreambooth_tab(
config=config,
)

# Add SDXL Parameters
sdxl_params = SDXLParameters(
source_model.sdxl_checkbox, config=config
)

with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"):
advanced_training = AdvancedTraining(headless=headless, config=config)
advanced_training.color_aug.change(
Expand Down Expand Up @@ -1112,6 +1129,8 @@ def dreambooth_tab(
advanced_training.log_tracker_name,
advanced_training.log_tracker_config,
advanced_training.scale_v_pred_loss_like_noise_pred,
sdxl_params.sdxl_cache_text_encoder_outputs,
sdxl_params.sdxl_no_half_vae,
advanced_training.min_timestep,
advanced_training.max_timestep,
advanced_training.debiased_estimation_loss,
Expand Down

0 comments on commit d96df32

Please sign in to comment.