From 25f77f6ef04ee760506338e7e7f9835c28657c59 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 17 Aug 2024 15:54:32 +0900 Subject: [PATCH] fix flux fine tuning to work --- README.md | 4 ++++ flux_train.py | 6 ++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e231cc24e..2b7b110f3 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,10 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` + +Aug 17, 2024: +Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training. + Aug 16, 2024: Added a script `networks/flux_merge_lora.py` to merge LoRA into FLUX.1 checkpoint. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. diff --git a/flux_train.py b/flux_train.py index 2ca20ded2..d2a9b3f32 100644 --- a/flux_train.py +++ b/flux_train.py @@ -674,9 +674,7 @@ def optimizer_hook(parameter: torch.Tensor): # if is_main_process: flux = accelerator.unwrap_model(flux) clip_l = accelerator.unwrap_model(clip_l) - clip_g = accelerator.unwrap_model(clip_g) - if t5xxl is not None: - t5xxl = accelerator.unwrap_model(t5xxl) + t5xxl = accelerator.unwrap_model(t5xxl) accelerator.end_training() @@ -686,7 +684,7 @@ def optimizer_hook(parameter: torch.Tensor): del accelerator # この後メモリを使うのでこれは消す if is_main_process: - flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux, ae) + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux) logger.info("model saved.")