add memory efficient training for FLUX.1
kohya-ss committed Aug 18, 2024
1 parent 25f77f6 commit ef535ec
Showing 3 changed files with 354 additions and 79 deletions.
README.md (64 changes: 56 additions & 8 deletions)
@@ -10,7 +10,11 @@ The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`


Aug 18, 2024:
Memory-efficient training based on 2kpr's implementation has been added to `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.


Aug 17, 2024:
Added a script `flux_train.py` to train FLUX.1. The script is experimental and not yet optimized. It needs >28GB VRAM for training.

Aug 16, 2024:
@@ -39,11 +43,23 @@ Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-ge

Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI.
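
If you have a LoRA file trained with the old prefix, the keys can be renamed. The snippet below is a hypothetical helper, not part of this repository: the exact key prefix (`lora_flex_`) and the file names are assumptions, and any safetensors metadata stored in the file is not carried over.

```
# Hypothetical key-renaming helper (not part of sd-scripts). Assumes old keys
# start with "lora_flex_"; safetensors metadata is not preserved here.
from safetensors.torch import load_file, save_file

state_dict = load_file("old-flux-lora.safetensors")
renamed = {k.replace("lora_flex_", "lora_unet_", 1): v for k, v in state_dict.items()}
save_file(renamed, "renamed-flux-lora.safetensors")
```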


### FLUX.1 LoRA training

We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. A sample command is below; the settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs.

```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py \
--pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors \
--ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers \
--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 \
--network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 \
--network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base \
--highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml \
--output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid \
--model_prediction_type raw --guidance_scale 1.0 --loss_type l2
```
(The command is split across multiple lines with `\` for readability; it can be run as is, or combined into a single line.)

Training can be done on 16GB VRAM GPUs with the Adafactor optimizer. Please use settings like the ones below:

@@ -80,12 +96,44 @@ The trained LoRA model can be used with ComfyUI.

The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.

Aug 12: `--interactive` option is now working.

```
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora "lora-flux-name.safetensors;1.0"
```

### FLUX.1 fine-tuning

A sample command for FLUX.1 fine-tuning is below. It will work with 24GB VRAM GPUs; 64GB of main memory is recommended.

```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py \
--pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft \
--mixed_precision bf16 --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 \
--seed 42 --gradient_checkpointing --save_precision bf16 \
--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name test-bf16 \
--learning_rate 5e-5 --max_train_epochs 4 --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 \
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" \
--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 \
--blockwise_fused_optimizer --double_blocks_to_swap 6 --cpu_offload_checkpointing
```

(The command is split across multiple lines with `\`; it can be run as is, or combined into a single line.)

Options are almost the same as for LoRA training. The differences are `--blockwise_fused_optimizer`, `--double_blocks_to_swap`, and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available.

`--blockwise_fused_optimizer` fuses the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated for FLUX.1 training now that this option is available.
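
As a rough illustration of the idea, the sketch below gives each block its own optimizer and runs the step as soon as the block's last gradient has been accumulated, so gradients for the whole model never have to be held at once. This is a hypothetical sketch, not the `flux_train.py` implementation: the module layout is made up and SGD stands in for Adafactor.

```
# Hypothetical sketch of a block-wise fused optimizer step (not the sd-scripts
# implementation). Each block gets its own optimizer; the step runs inside the
# gradient-accumulation hook of the block's last parameter, and the block's
# gradients are freed immediately. Requires PyTorch 2.1 or later.
import torch

blocks = torch.nn.ModuleList(torch.nn.Linear(256, 256) for _ in range(4))  # stand-in blocks

for block in blocks:
    params = list(block.parameters())
    opt = torch.optim.SGD(params, lr=1e-4)  # the real script would use Adafactor
    state = {"remaining": len(params)}

    def make_hook(opt=opt, state=state, total=len(params)):
        def hook(param):
            state["remaining"] -= 1
            if state["remaining"] == 0:          # last gradient of this block is ready
                opt.step()                       # update the whole block now
                opt.zero_grad(set_to_none=True)  # free its gradients immediately
                state["remaining"] = total       # reset for the next backward pass
        return hook

    for p in params:
        p.register_post_accumulate_grad_hook(make_hook())

x = torch.randn(8, 256)
for block in blocks:
    x = block(x)
x.square().mean().backward()  # blocks are updated during backward; no global optimizer.step() needed
```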

`--double_blocks_to_swap` and `--single_blocks_to_swap` specify the number of double blocks and single blocks to swap out to main memory. The default is None (no swapping). These options must be combined with `--blockwise_fused_optimizer`.
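
The sketch below is a toy, forward-only illustration of the swapping idea, not the actual implementation: it does not handle the backward pass or optimizer state, which the real script must also manage. The last N blocks live in main memory and are moved onto the GPU only while they run.

```
# Toy illustration of block swapping (forward pass only; not the sd-scripts
# implementation, which also handles gradients and optimizer state).
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
blocks = torch.nn.ModuleList(torch.nn.Linear(256, 256) for _ in range(8))  # stand-in blocks
blocks_to_swap = 6  # analogous to --double_blocks_to_swap 6

def load_to_gpu(module, args):
    module.to(device)  # bring the block in just before it runs

def offload_to_cpu(module, args, output):
    module.to("cpu")  # push it back to main memory right after

for i, block in enumerate(blocks):
    if i < len(blocks) - blocks_to_swap:
        block.to(device)  # resident blocks stay on the GPU
    else:
        block.to("cpu")  # swapped blocks live in main memory
        block.register_forward_pre_hook(load_to_gpu)
        block.register_forward_hook(offload_to_cpu)

x = torch.randn(4, 256, device=device)
with torch.no_grad():
    for block in blocks:
        x = block(x)
```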

`--cpu_offload_checkpointing` offloads gradient checkpointing (the activations saved for backward) to the CPU. This reduces VRAM usage by about 2GB.
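
One way to picture the offloading half of this option is PyTorch's built-in `save_on_cpu` saved-tensors hook, shown below as a rough sketch. This is not the `--cpu_offload_checkpointing` implementation, which combines offloading with gradient checkpointing; it only illustrates keeping saved activations in CPU memory.

```
# Rough sketch of offloading activations saved for backward to CPU memory,
# using a standard PyTorch hook (not the --cpu_offload_checkpointing code).
import torch
from torch.autograd.graph import save_on_cpu

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(*(torch.nn.Linear(1024, 1024) for _ in range(4))).to(device)
x = torch.randn(8, 1024, device=device)

# Tensors saved for backward are kept in (pinned) CPU memory instead of VRAM
# and copied back to the GPU only when backward actually needs them.
with save_on_cpu(pin_memory=torch.cuda.is_available()):
    loss = model(x).square().mean()
loss.backward()
```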

All these options are experimental and may change in the future.

Increasing the number of blocks to swap may reduce memory usage, but training will be slower. `--cpu_offload_checkpointing` also slows down training.

Swapping 6 double blocks and using CPU offload checkpointing may be a good starting point. Please try different settings depending on your VRAM usage and training speed.

The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results.

### Merge LoRA to FLUX.1 checkpoint

`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__
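
Conceptually, merging a LoRA folds its low-rank update into each targeted base weight. The sketch below shows only that arithmetic, with made-up shapes and the usual `lora_up` / `lora_down` / `alpha` naming; it is not the `flux_merge_lora.py` code.

```
# Conceptual arithmetic of a LoRA merge (not the flux_merge_lora.py code):
# W_merged = W + multiplier * (alpha / rank) * (lora_up @ lora_down)
import torch

rank, alpha, multiplier = 4, 4.0, 1.0             # rank matches --network_dim 4; alpha and multiplier are illustrative
W = torch.randn(3072, 3072)                       # a base linear weight (shape illustrative)
lora_down = torch.randn(rank, W.shape[1]) * 0.01  # "lora_down" (A) matrix
lora_up = torch.randn(W.shape[0], rank) * 0.01    # "lora_up" (B) matrix

W_merged = W + multiplier * (alpha / rank) * (lora_up @ lora_down)
```
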
@@ -298,7 +346,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
- The memory usage during training is significantly reduced by integrating the optimizer step into the backward pass. The training results are the same as before, but if you have plenty of memory, the speed will be slower.
- Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only Adafactor is supported. Gradient accumulation is not available.
- Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`.
- Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size.
- PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`.
@@ -308,7 +356,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer.
- Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10.
- Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available.
- `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using Adafactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required.
- Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side.

- LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO!
@@ -361,7 +409,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!

- Fused optimizer is now available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
- Memory usage during training is greatly reduced by integrating the optimizer step into the backward pass. The training results are the same as without it, but if you have plenty of memory, the speed will be slower.
- Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only Adafactor is supported as the optimizer. Gradient accumulation is not available.
- Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`.
- Training seems to be possible with about 17GB of memory at batch size 1 with fp32. Specifying the `--full_bf16` option reduces memory usage further (with lower accuracy). With the same memory usage as before, the batch size can be increased.
- PyTorch 2.1 or later is required because the new API `Tensor.register_post_accumulate_grad_hook(hook)` is used.