From ef535ec6bb99918027afc1e31efa72cd3761d453 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 18 Aug 2024 16:54:18 +0900
Subject: [PATCH] add memory efficient training for FLUX.1

---
 README.md              |  64 ++++++++++++--
 flux_train.py          | 187 +++++++++++++++++++++++++++++------
 library/flux_models.py | 182 ++++++++++++++++++++++++++++++++++-----
 3 files changed, 354 insertions(+), 79 deletions(-)

diff --git a/README.md b/README.md
index 2b7b110f3..521e82e86 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,11 @@ The command to install PyTorch is as follows:
 `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
 
-Aug 17. 2024:
+Aug 18, 2024:
+Memory-efficient training based on 2kpr's implementation is now available in `flux_train.py`. Thanks to 2kpr. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.
+
+
+Aug 17, 2024:
 Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training.
 
 Aug 16, 2024:

@@ -39,11 +43,23 @@ Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-ge
 Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI.
 
-We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs.
+
+### FLUX.1 LoRA training
+
+We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. A sample command is below; settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. 
``` -accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py +--pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors +--ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers +--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 +--network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 +--network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base +--highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml +--output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid +--model_prediction_type raw --guidance_scale 1.0 --loss_type l2 ``` +(The command is multi-line for readability. Please combine it into one line.) The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: @@ -80,12 +96,44 @@ The trained LoRA model can be used with ComfyUI. The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. -Aug 12: `--interactive` option is now working. - ``` python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` +### FLUX.1 fine-tuning + +Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended. 
+
+```
+accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py
+--pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft
+--mixed_precision bf16 --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2
+--seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
+--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name test-bf16
+--learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1
+--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"
+--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0
+--blockwise_fused_optimizers --double_blocks_to_swap 6 --cpu_offload_checkpointing
+```
+
+(Combine the command into one line.)
+
+Options are almost the same as for LoRA training. The differences are `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available.
+
+`--blockwise_fused_optimizers` fuses the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated for FLUX.1 training in favor of this option.
+
+`--double_blocks_to_swap` and `--single_blocks_to_swap` set the number of double blocks and single blocks to swap between CPU and GPU. The default is None (no swapping). These options must be combined with `--blockwise_fused_optimizers`.
+
+`--cpu_offload_checkpointing` offloads gradient checkpointing tensors to the CPU. This reduces VRAM usage by about 2GB.
+
+All these options are experimental and may change in the future.
+
+Increasing the number of blocks to swap reduces memory usage but slows down training. `--cpu_offload_checkpointing` also slows down training.
+
+Swapping 6 double blocks with CPU offload checkpointing may be a good starting point. Please try different settings according to your VRAM usage and training speed.
+
+The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results.
+
 ### Merge LoRA to FLUX.1 checkpoint
 
 `networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__
 
@@ -298,7 +346,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 
 - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
   - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower.
-  - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available.
+  - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only Adafactor is supported. Gradient accumulation is not available.
   - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`.
   - Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). 
With the same memory usage as before, you can increase the batch size. - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`. @@ -308,7 +356,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer. - Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10. - Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available. - - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. + - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using Adafactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. - Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side. - LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO! @@ -361,7 +409,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 - - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。 + - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は Adafactor のみ対応しています。また gradient accumulation は使えません。 - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。 - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。 - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。 diff --git a/flux_train.py b/flux_train.py index d2a9b3f32..ecb3c7dda 100644 --- a/flux_train.py +++ b/flux_train.py @@ -1,5 +1,15 @@ # training with captions +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. 
+ +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + import argparse import copy import math @@ -54,6 +64,12 @@ def train(args): ) args.cache_text_encoder_outputs = True + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True + cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -232,16 +248,25 @@ def train(args): # now we can delete Text Encoders to free memory clip_l = None t5xxl = None + clean_memory_on_device(accelerator.device) # load FLUX # if we load to cpu, flux.to(fp8) takes a long time flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") if args.gradient_checkpointing: - flux.enable_gradient_checkpointing() + flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing) flux.requires_grad_(True) + if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info( + f"enable block swap: double_blocks_to_swap={args.double_blocks_to_swap}, single_blocks_to_swap={args.single_blocks_to_swap}" + ) + flux.enable_block_swap(args.double_blocks_to_swap, args.single_blocks_to_swap) + if not cache_latents: # load VAE here if not cached ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") @@ -265,40 +290,43 @@ def train(args): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html - # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. # This balances memory usage and management complexity. - # calculate total number of parameters - n_total_params = sum(len(params["params"]) for params in params_to_optimize) - params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) - - # split params into groups, keeping the learning rate the same for all params in a group - # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + # split params into groups. 
currently different learning rates are not supported grouped_params = [] - param_group = [] - param_group_lr = -1 + param_group = {} for group in params_to_optimize: - lr = group["lr"] - for p in group["params"]: - # if the learning rate is different for different params, start a new group - if lr != param_group_lr: - if param_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) - param_group = [] - param_group_lr = lr - - param_group.append(p) - - # if the group has enough parameters, start a new group - if len(param_group) == params_per_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) - param_group = [] - param_group_lr = -1 - - if param_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) + named_parameters = list(flux.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_idx = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_idx = int(np[0].split(".")[1]) + block_type = "single" + else: + block_idx = -1 + + param_group_key = (block_type, block_idx) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") # prepare optimizers for each group optimizers = [] @@ -307,7 +335,7 @@ def train(args): optimizers.append(optimizer) optimizer = optimizers[0] # avoid error in the following code - logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) @@ -341,7 +369,7 @@ def train(args): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: # prepare lr schedulers for each optimizer lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] lr_scheduler = lr_schedulers[0] # avoid error in the following code @@ -414,7 +442,7 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): parameter.register_post_accumulate_grad_hook(__grad_hook) - elif args.fused_optimizer_groups: + elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers for i in range(1, len(optimizers)): optimizers[i] = accelerator.prepare(optimizers[i]) @@ -429,22 +457,46 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} + double_blocks_to_swap = args.double_blocks_to_swap + single_blocks_to_swap = args.single_blocks_to_swap + num_double_blocks = len(flux.double_blocks) + num_single_blocks = len(flux.single_blocks) + for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in 
param_group["params"]: if parameter.requires_grad: - - def optimizer_hook(parameter: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) - - parameter.register_post_accumulate_grad_hook(optimizer_hook) + block_type, block_idx = block_types_and_indices[opt_idx] + + def create_optimizer_hook(btype, bidx): + def optimizer_hook(parameter: torch.Tensor): + # print(f"optimizer_hook: {btype}, {bidx}") + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + # swap blocks if necessary + if btype == "double" and double_blocks_to_swap: + if bidx >= num_double_blocks - double_blocks_to_swap: + bidx_cuda = double_blocks_to_swap - (num_double_blocks - bidx) + flux.double_blocks[bidx].to("cpu") + flux.double_blocks[bidx_cuda].to(accelerator.device) + # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") + elif btype == "single" and single_blocks_to_swap: + if bidx >= num_single_blocks - single_blocks_to_swap: + bidx_cuda = single_blocks_to_swap - (num_single_blocks - bidx) + flux.single_blocks[bidx].to("cpu") + flux.single_blocks[bidx_cuda].to(accelerator.device) + # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") + + return optimizer_hook + + parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx)) parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 @@ -487,6 +539,9 @@ def optimizer_hook(parameter: torch.Tensor): init_kwargs=init_kwargs, ) + if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + flux.prepare_block_swap_before_forward() + # For --sample_at_first flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) @@ -502,7 +557,7 @@ def optimizer_hook(parameter: torch.Tensor): for step, batch in enumerate(train_dataloader): current_step.value = global_step - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step with accelerator.accumulate(*training_models): @@ -591,7 +646,7 @@ def optimizer_hook(parameter: torch.Tensor): # backward accelerator.backward(loss) - if not (args.fused_backward_pass or args.fused_optimizer_groups): + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] for m in training_models: @@ -604,7 +659,7 @@ def optimizer_hook(parameter: torch.Tensor): else: # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook lr_scheduler.step() - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: for i in range(1, len(optimizers)): lr_schedulers[i].step() @@ -614,7 +669,7 @@ def optimizer_hook(parameter: torch.Tensor): global_step += 1 flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + accelerator, args, None, global_step, 
flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) # 指定ステップごとにモデルを保存 @@ -673,8 +728,6 @@ def optimizer_hook(parameter: torch.Tensor): is_main_process = accelerator.is_main_process # if is_main_process: flux = accelerator.unwrap_model(flux) - clip_l = accelerator.unwrap_model(clip_l) - t5xxl = accelerator.unwrap_model(t5xxl) accelerator.end_training() @@ -707,13 +760,43 @@ def setup_parser() -> argparse.ArgumentParser: "--fused_optimizer_groups", type=int, default=None, - help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", ) parser.add_argument( "--skip_latents_validity_check", action="store_true", help="skip latents validity check / latentsの正当性チェックをスキップする", ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップする'変換ブロック'(約640MB)の数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) + parser.add_argument( + "--single_blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." 
+ " / 順伝播および逆伝播中にスワップする'変換ブロック'(約320MB)の数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", + ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index ed0bc8c7d..3f44068f9 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -4,6 +4,11 @@ from dataclasses import dataclass import math +from typing import Optional + +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() import torch from einops import rearrange @@ -466,6 +471,33 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso # region layers + + +# for cpu_offload_checkpointing + + +def to_cuda(x): + if isinstance(x, torch.Tensor): + return x.cuda() + elif isinstance(x, (list, tuple)): + return [to_cuda(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cuda(v) for k, v in x.items()} + else: + return x + + +def to_cpu(x): + if isinstance(x, torch.Tensor): + return x.cpu() + elif isinstance(x, (list, tuple)): + return [to_cpu(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cpu(v) for k, v in x.items()} + else: + return x + + class EmbedND(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: list[int]): super().__init__() @@ -648,16 +680,15 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: ) self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False - def enable_gradient_checkpointing(self): + def enable_gradient_checkpointing(self, cpu_offload: bool = False): self.gradient_checkpointing = True - # self.img_attn.enable_gradient_checkpointing() - # self.txt_attn.enable_gradient_checkpointing() + self.cpu_offload_checkpointing = cpu_offload def disable_gradient_checkpointing(self): self.gradient_checkpointing = False - # self.img_attn.disable_gradient_checkpointing() - # self.txt_attn.disable_gradient_checkpointing() + self.cpu_offload_checkpointing = False def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: img_mod1, img_mod2 = self.img_mod(vec) @@ -694,11 +725,24 @@ def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[T txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) return img, txt - def forward(self, *args, **kwargs): + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: if self.training and self.gradient_checkpointing: - return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, img, txt, vec, pe, use_reentrant=False) + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe) + else: - return self._forward(*args, **kwargs) + return self._forward(img, txt, vec, pe) # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): # if self.training and self.gradient_checkpointing: @@ -747,12 +791,15 @@ def __init__( self.modulation = Modulation(hidden_size, double=False) self.gradient_checkpointing = False + 
self.cpu_offload_checkpointing = False - def enable_gradient_checkpointing(self): + def enable_gradient_checkpointing(self, cpu_offload: bool = False): self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload def disable_gradient_checkpointing(self): self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: mod, _ = self.modulation(vec) @@ -768,11 +815,24 @@ def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) return x + mod.gate * output - def forward(self, *args, **kwargs): + def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: if self.training and self.gradient_checkpointing: - return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, x, vec, pe, use_reentrant=False) + + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe) else: - return self._forward(*args, **kwargs) + return self._forward(x, vec, pe) # def forward(self, x: Tensor, vec: Tensor, pe: Tensor): # if self.training and self.gradient_checkpointing: @@ -849,6 +909,9 @@ def __init__(self, params: FluxParams): self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.double_blocks_to_swap = None + self.single_blocks_to_swap = None @property def device(self): @@ -858,8 +921,9 @@ def device(self): def dtype(self): return next(self.parameters()).dtype - def enable_gradient_checkpointing(self): + def enable_gradient_checkpointing(self, cpu_offload: bool = False): self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload self.time_in.enable_gradient_checkpointing() self.vector_in.enable_gradient_checkpointing() @@ -867,12 +931,13 @@ def enable_gradient_checkpointing(self): self.guidance_in.enable_gradient_checkpointing() for block in self.double_blocks + self.single_blocks: - block.enable_gradient_checkpointing() + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) - print("FLUX: Gradient checkpointing enabled.") + print(f"FLUX: Gradient checkpointing enabled. 
CPU offload: {cpu_offload}") def disable_gradient_checkpointing(self): self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False self.time_in.disable_gradient_checkpointing() self.vector_in.disable_gradient_checkpointing() @@ -884,6 +949,24 @@ def disable_gradient_checkpointing(self): print("FLUX: Gradient checkpointing disabled.") + def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optional[int]): + self.double_blocks_to_swap = double_blocks + self.single_blocks_to_swap = single_blocks + + def prepare_block_swap_before_forward(self): + # move last n blocks to cpu: they are on cuda + if self.double_blocks_to_swap: + for i in range(len(self.double_blocks) - self.double_blocks_to_swap): + self.double_blocks[i].to(self.device) + for i in range(len(self.double_blocks) - self.double_blocks_to_swap, len(self.double_blocks)): + self.double_blocks[i].to("cpu") # , non_blocking=True) + if self.single_blocks_to_swap: + for i in range(len(self.single_blocks) - self.single_blocks_to_swap): + self.single_blocks[i].to(self.device) + for i in range(len(self.single_blocks) - self.single_blocks_to_swap, len(self.single_blocks)): + self.single_blocks[i].to("cpu") # , non_blocking=True) + clean_memory_on_device(self.device) + def forward( self, img: Tensor, @@ -910,14 +993,75 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) - for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + if not self.double_blocks_to_swap: + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + else: + # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning + for block_idx in range(self.double_blocks_to_swap): + block = self.double_blocks[len(self.double_blocks) - self.double_blocks_to_swap + block_idx] + if block.parameters().__next__().device.type != "cpu": + block.to("cpu") # , non_blocking=True) + # print(f"Moved double block {len(self.double_blocks) - self.double_blocks_to_swap + block_idx} to cpu.") + + block = self.double_blocks[block_idx] + if block.parameters().__next__().device.type == "cpu": + block.to(self.device) + # print(f"Moved double block {block_idx} to cuda.") + + to_cpu_block_index = 0 + for block_idx, block in enumerate(self.double_blocks): + # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda + moving = block_idx >= len(self.double_blocks) - self.double_blocks_to_swap + if moving: + block.to(self.device) # move to cuda + # print(f"Moved double block {block_idx} to cuda.") + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + if moving: + self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) + # print(f"Moved double block {to_cpu_block_index} to cpu.") + to_cpu_block_index += 1 img = torch.cat((txt, img), 1) - for block in self.single_blocks: - img = block(img, vec=vec, pe=pe) + + if not self.single_blocks_to_swap: + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + else: + # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning + for block_idx in range(self.single_blocks_to_swap): + block = self.single_blocks[len(self.single_blocks) - self.single_blocks_to_swap + block_idx] + if block.parameters().__next__().device.type != "cpu": + block.to("cpu") # , non_blocking=True) + # print(f"Moved single block {len(self.single_blocks) - self.single_blocks_to_swap + block_idx} to cpu.") + + block = self.single_blocks[block_idx] + if 
block.parameters().__next__().device.type == "cpu":
+                    block.to(self.device)
+                    # print(f"Moved single block {block_idx} to cuda.")
+
+            to_cpu_block_index = 0
+            for block_idx, block in enumerate(self.single_blocks):
+                # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda
+                moving = block_idx >= len(self.single_blocks) - self.single_blocks_to_swap
+                if moving:
+                    block.to(self.device)  # move to cuda
+                    # print(f"Moved single block {block_idx} to cuda.")
+
+                img = block(img, vec=vec, pe=pe)
+
+                if moving:
+                    self.single_blocks[to_cpu_block_index].to("cpu")  # , non_blocking=True)
+                    # print(f"Moved single block {to_cpu_block_index} to cpu.")
+                    to_cpu_block_index += 1
+
        img = img[:, txt.shape[1] :, ...]
+
+        if self.training and self.cpu_offload_checkpointing:
+            img = img.to(self.device)
+            vec = vec.to(self.device)
+
        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
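
As a quick illustration of the two mechanisms this patch combines — blockwise fused optimizer steps driven by `Tensor.register_post_accumulate_grad_hook` (PyTorch 2.1+) and swapping of transformer blocks between CPU and GPU — here is a minimal, self-contained sketch. It is not the patch's implementation: the toy `nn.Linear` blocks, `torch.optim.AdamW`, and the simplified swap policy (each swapped block is moved back to the CPU right after use, instead of rotating earlier blocks out as `Flux.forward` above does) are stand-ins for illustration only.

```
import torch
import torch.nn as nn

# Four toy blocks standing in for the FLUX.1 double/single blocks.
model = nn.Sequential(*[nn.Linear(64, 64) for _ in range(4)])

# Blockwise fused optimizers: one optimizer per block, stepped from a gradient hook
# so each block's gradients can be released as soon as that block finishes backward.
optimizers = [torch.optim.AdamW(block.parameters(), lr=1e-4) for block in model]
params_per_block = [sum(1 for _ in block.parameters()) for block in model]
hooked_count = [0] * len(model)

def make_hook(block_idx: int):
    def hook(param: torch.Tensor):
        hooked_count[block_idx] += 1
        if hooked_count[block_idx] == params_per_block[block_idx]:
            # the last gradient of this block has been accumulated: step and free grads now
            optimizers[block_idx].step()
            optimizers[block_idx].zero_grad(set_to_none=True)
            hooked_count[block_idx] = 0
    return hook

for block_idx, block in enumerate(model):
    for p in block.parameters():
        p.register_post_accumulate_grad_hook(make_hook(block_idx))  # PyTorch 2.1+

# One training step: backward() drives the per-block optimizer steps via the hooks,
# so no explicit optimizer.step() / zero_grad() is needed afterwards.
x = torch.randn(8, 64)
loss = model(x).pow(2).mean()
loss.backward()

# Block swapping, simplified: keep the last N blocks on the CPU and move each one to
# the GPU only while it is needed in the forward pass.
if torch.cuda.is_available():
    blocks_to_swap = 2
    device = torch.device("cuda")
    for i, block in enumerate(model):
        block.to(device if i < len(model) - blocks_to_swap else "cpu")

    h = torch.randn(8, 64, device=device)
    with torch.no_grad():
        for i, block in enumerate(model):
            swapped = i >= len(model) - blocks_to_swap
            if swapped:
                block.to(device)  # bring the block in just before it is used
            h = block(h)
            if swapped:
                block.to("cpu")  # and push it back out right after
```

In `flux_train.py` above, the same hook additionally clips gradients and, once a block's optimizer has stepped, moves that block to the CPU and brings the corresponding swapped-out block back to the GPU for the next step.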