From 2889108d858880589d362e06e98eeadf4682476a Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Thu, 5 Sep 2024 20:58:33 +0900
Subject: [PATCH] feat: Add --cpu_offload_checkpointing option to LoRA training

---
 README.md             |  7 +++++++
 flux_train.py         |  2 +-
 flux_train_network.py |  5 +++++
 train_network.py      | 12 +++++++++++-
 4 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index fa81f6c0f..e8a12089f 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,12 @@ The command to install PyTorch is as follows:
 
 ### Recent Updates
 
+Sep 5, 2024 (update 1):
+
+Added the `--cpu_offload_checkpointing` option to the LoRA training script. It offloads tensors to the CPU during gradient checkpointing, reducing VRAM usage by up to 1GB at the cost of roughly 15% slower training. It cannot be used with `--split_mode`.
+
 Sep 5, 2024:
+
 The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details.
 
 Sep 4, 2024:
@@ -72,6 +77,8 @@ The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_
 --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
 ```
 
+`--cpu_offload_checkpointing` offloads tensors to the CPU during gradient checkpointing. This reduces VRAM usage by up to 1GB but slows training by about 15%. It cannot be used with `--split_mode`.
+
 We are also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted.
 
 The trained LoRA model can be used with ComfyUI.

diff --git a/flux_train.py b/flux_train.py
index 0293b7be3..0edc83a9f 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -261,7 +261,7 @@ def train(args):
     )
 
     if args.gradient_checkpointing:
-        flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing)
+        flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
 
     flux.requires_grad_(True)
 

diff --git a/flux_train_network.py b/flux_train_network.py
index 2fc0f3234..a6e57eede 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -50,6 +50,11 @@ def assert_extra_args(self, args, train_dataset_group):
         if args.max_token_length is not None:
             logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
 
+        assert not args.split_mode or not args.cpu_offload_checkpointing, (
+            "split_mode and cpu_offload_checkpointing cannot be used together"
+            " / split_modeとcpu_offload_checkpointingは同時に使用できません"
+        )
+
         train_dataset_group.verify_bucket_reso_steps(32)  # TODO check this
 
     def get_flux_model_name(self, args):

diff --git a/train_network.py b/train_network.py
index a68ccfcc4..ad97491df 100644
--- a/train_network.py
+++ b/train_network.py
@@ -451,7 +451,11 @@ def train(self, args):
             accelerator.print(f"load network weights from {args.network_weights}: {info}")
 
         if args.gradient_checkpointing:
-            unet.enable_gradient_checkpointing()
+            if args.cpu_offload_checkpointing:
+                unet.enable_gradient_checkpointing(cpu_offload=True)
+            else:
+                unet.enable_gradient_checkpointing()
+
             for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
                 if flag:
                     if t_enc.supports_gradient_checkpointing:
@@ -1281,6 +1285,12 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
+    parser.add_argument(
+        "--cpu_offload_checkpointing",
+        action="store_true",
+        help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported"
+        " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)",
+    )
     parser.add_argument(
         "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
    )
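
Note on the mechanism: the patch routes the new flag into each model's `enable_gradient_checkpointing(cpu_offload=...)` method but does not include that method's body. The sketch below is illustrative only, not the repository's implementation; it assumes plain PyTorch and shows one way to layer CPU offload on top of gradient checkpointing using `torch.autograd.graph.save_on_cpu`, which matches the reported trade-off of lower VRAM use for extra host/device copies. The class name `BlockWithOffloadableCheckpointing` is invented for the example.

```python
# Illustrative sketch, not the repository's implementation: CPU offload layered
# on top of standard PyTorch gradient checkpointing. Checkpointing recomputes
# activations during backward, and save_on_cpu() parks the tensors that are
# still saved (the checkpoint inputs) in host memory until they are needed.
import torch
from torch.utils.checkpoint import checkpoint


class BlockWithOffloadableCheckpointing(torch.nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.inner = torch.nn.Sequential(
            torch.nn.Linear(dim, dim), torch.nn.GELU(), torch.nn.Linear(dim, dim)
        )
        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

    def enable_gradient_checkpointing(self, cpu_offload: bool = False):
        # Mirrors the call sites in the patch: enable_gradient_checkpointing(cpu_offload=...)
        self.gradient_checkpointing = True
        self.cpu_offload_checkpointing = cpu_offload

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.gradient_checkpointing:
            return self.inner(x)
        if self.cpu_offload_checkpointing:
            # Saved tensors are moved to pinned host memory and copied back to
            # the GPU on demand during the backward pass.
            with torch.autograd.graph.save_on_cpu(pin_memory=True):
                return checkpoint(self.inner, x, use_reentrant=False)
        return checkpoint(self.inner, x, use_reentrant=False)


if __name__ == "__main__":
    block = BlockWithOffloadableCheckpointing().cuda()
    block.enable_gradient_checkpointing(cpu_offload=True)
    out = block(torch.randn(8, 64, device="cuda", requires_grad=True))
    out.sum().backward()  # recomputes the block, pulling saved tensors back from CPU
```

With the patch applied, the feature is enabled from the command line by adding `--cpu_offload_checkpointing` next to `--gradient_checkpointing` in the `flux_train_network.py` command shown in the README; `--split_mode` must be dropped, since the new assertion rejects the two options together.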