feat: Add --cpu_offload_checkpointing option to LoRA training
kohya-ss committed Sep 5, 2024
1 parent d912952 commit 2889108
Showing 4 changed files with 24 additions and 2 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -11,7 +11,12 @@ The command to install PyTorch is as follows:

### Recent Updates

Sep 5, 2024 (update 1):

Added the `--cpu_offload_checkpointing` option to the LoRA training script. It offloads gradient checkpointing to the CPU, which reduces VRAM usage by up to 1GB but slows training by about 15%. It cannot be used together with `--split_mode`.

Sep 5, 2024:

The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl` for the CLIP-L and T5XXL models; `--clip_l_save_to` and `--t5xxl_save_to` specify the save destinations for the merged CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details.
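
For illustration, one way these options might be combined (the script path `networks/flux_merge_lora.py`, the file names, and the placeholder for the FLUX.1 model and LoRA arguments are assumptions for this sketch; only the four options above come from this update):

```
python networks/flux_merge_lora.py <FLUX.1 model and LoRA arguments> \
  --clip_l clip_l.safetensors --clip_l_save_to clip_l_merged.safetensors \
  --t5xxl t5xxl.safetensors --t5xxl_save_to t5xxl_merged.safetensors
```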

Sep 4, 2024:
@@ -72,6 +77,8 @@ The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
```

`--cpu_offload_checkpointing` offloads gradient checkpointing to the CPU. This reduces VRAM usage by up to 1GB but slows training by about 15%. It cannot be used together with `--split_mode`.
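
As an illustrative sketch (the `accelerate launch` invocation and the placeholder arguments are assumptions, not an exact command from this repository), the example above could be adapted as follows. `--split_mode` is dropped because the two options are mutually exclusive (which may change the VRAM requirement), and the offload only takes effect when `--gradient_checkpointing` is also enabled:

```
accelerate launch flux_train_network.py <model, dataset and network arguments> \
  --gradient_checkpointing --cpu_offload_checkpointing \
  --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" \
  --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
```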

We are also not sure how many epochs are needed for convergence, or how the learning rate should be adjusted.

The trained LoRA model can be used with ComfyUI.
2 changes: 1 addition & 1 deletion flux_train.py
@@ -261,7 +261,7 @@ def train(args):
    )

    if args.gradient_checkpointing:
-        flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing)
+        flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)

    flux.requires_grad_(True)

5 changes: 5 additions & 0 deletions flux_train_network.py
@@ -50,6 +50,11 @@ def assert_extra_args(self, args, train_dataset_group):
        if args.max_token_length is not None:
            logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")

+        assert not args.split_mode or not args.cpu_offload_checkpointing, (
+            "split_mode and cpu_offload_checkpointing cannot be used together"
+            " / split_modeとcpu_offload_checkpointingは同時に使用できません"
+        )
+
        train_dataset_group.verify_bucket_reso_steps(32)  # TODO check this

    def get_flux_model_name(self, args):
12 changes: 11 additions & 1 deletion train_network.py
@@ -451,7 +451,11 @@ def train(self, args):
accelerator.print(f"load network weights from {args.network_weights}: {info}")

if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.cpu_offload_checkpointing:
unet.enable_gradient_checkpointing(cpu_offload=True)
else:
unet.enable_gradient_checkpointing()

for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
if flag:
if t_enc.supports_gradient_checkpointing:
@@ -1281,6 +1285,12 @@ def setup_parser() -> argparse.ArgumentParser:
    config_util.add_config_arguments(parser)
    custom_train_functions.add_custom_train_arguments(parser)

+    parser.add_argument(
+        "--cpu_offload_checkpointing",
+        action="store_true",
+        help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported"
+        " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)",
+    )
    parser.add_argument(
        "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
    )