From 4e56a20b677ca04b3f0d600ce1299cab3bcdc266 Mon Sep 17 00:00:00 2001 From: Joey Overby <38049004+JoeyOverby@users.noreply.github.com> Date: Mon, 16 Sep 2024 10:36:22 -0600 Subject: [PATCH] Update submodules as well since MPS changes were needed --- README.md | 170 ++- fine_tune.py | 20 +- gen_img.py | 55 +- library/.DS_Store | Bin 0 -> 6148 bytes library/adafactor_fused.py | 106 ++ library/config_util.py | 7 + library/custom_train_functions.py | 14 +- library/ipex/attention.py | 2 +- library/sai_model_spec.py | 22 +- library/sd3_models.py | 2031 ++++++++++++++++++++++++++ library/sd3_train_utils.py | 656 +++++++++ library/sd3_utils.py | 513 +++++++ library/sdxl_model_util.py | 16 +- library/sdxl_train_util.py | 13 +- library/train_util.py | 562 ++++++- requirements.txt | 2 +- sd3_minimal_inference.py | 351 +++++ sd3_train.py | 981 +++++++++++++ sdxl_train.py | 172 ++- sdxl_train_control_net_lllite.py | 16 +- sdxl_train_control_net_lllite_old.py | 2 +- train_controlnet.py | 42 +- train_db.py | 4 +- train_network.py | 207 ++- train_textual_inversion.py | 4 +- train_textual_inversion_XTI.py | 4 +- 26 files changed, 5785 insertions(+), 187 deletions(-) create mode 100644 library/.DS_Store create mode 100644 library/adafactor_fused.py create mode 100644 library/sd3_models.py create mode 100644 library/sd3_train_utils.py create mode 100644 library/sd3_utils.py create mode 100644 sd3_minimal_inference.py create mode 100644 sd3_train.py diff --git a/README.md b/README.md index 0be2f9a70..5d4f9621d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,39 @@ This repository contains training, generation and utility scripts for Stable Diffusion. +## SD3 training + +SD3 training is done with `sd3_train.py`. + +__Jul 11, 2024__: Fixed to work t5xxl with `fp16`. If you change the dtype to `fp16` for t5xxl, please remove existing latents cache files (`*_sd3.npz`). The shift in `sd3_minimum_inference.py` is fixed to 3.0. Thanks to araleza! + +Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result). + +`fp16` and `bf16` are available for mixed precision training. We are not sure which is better. + +`optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently. + +`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. + +~~t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. ~~ t5xxl works with `fp16` now. + +There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. + +`text_encoder_batch_size` is added experimentally for caching faster. + +```toml +learning_rate = 1e-6 # seems to depend on the batch size +optimizer_type = "adafactor" +optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] +cache_text_encoder_outputs = true +cache_text_encoder_outputs_to_disk = true +vae_batch_size = 1 +text_encoder_batch_size = 4 +cache_latents = true +cache_latents_to_disk = true +``` + +--- + [__Change History__](#change-history) is moved to the bottom of the page. 更新履歴は[ページ末尾](#change-history)に移しました。 @@ -137,19 +171,133 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History -### Sep 13, 2024 / 2024-09-13: +### Working in progress + +- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! + - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. + - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. + - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`. + - Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size. + - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`. + - Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. The more parameters there are, the greater the effect, so it is not effective in other training scripts (LoRA, etc.) where the memory usage peak is elsewhere, and there are no plans to implement it in those training scripts. + +- Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) + - Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer. + - Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10. + - Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available. + - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. + - Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side. + +- LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO! + - LoRA+ is a method to improve training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiple. The original paper recommends 16, but adjust as needed. Please see the PR for details. + - Specify `loraplus_lr_ratio` with `--network_args`. Example: `--network_args "loraplus_lr_ratio=16"` + - `loraplus_unet_lr_ratio` and `loraplus_lr_ratio` can be specified separately for U-Net and Text Encoder. + - Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc. + - `network_module` `networks.lora` and `networks.dylora` are available. + +- The feature to use the transparency (alpha channel) of the image as a mask in the loss calculation has been added. PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) Thanks to u-haru! + - The transparent part is ignored during training. Specify the `--alpha_mask` option in the training script or specify `alpha_mask = true` in the dataset configuration file. + - See [About masked loss](./docs/masked_loss_README.md) for details. + +- LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331) + - Specify the learning rate and dim (rank) for each block. + - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only). + +- Negative learning rates can now be specified during SDXL model training. PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Thanks to Cauldrath! + - The model is trained to move away from the training images, so the model is easily collapsed. Use with caution. A value close to 0 is recommended. + - When specifying from the command line, use `=` like `--learning_rate=-1e-7`. + +- Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa! + - Some settings, such as API keys and directory specifications, are not output due to security issues. + +- The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds! + +- `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf! + - This resolves the issue where the order of data loading from DataSet changes when resuming training. + - Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before). + - If `--resume` is specified, the step saved in the state is used. + - Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`). + +- An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra! + - It seems that the model file loading is faster in the WSL environment etc. + - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`. + +- When there is an error in the cached latents file on disk, the file name is now displayed. PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Thanks to Cauldrath! + +- Fixed an error that occurs when specifying `--max_dataloader_n_workers` in `tag_images_by_wd14_tagger.py` when Onnx is not used. PR [#1291]( +https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290]( +https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! + +- Fixed a bug that `caption_separator` cannot be specified in the subset in the dataset settings .toml file. [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) and [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) Thanks to rockerBOO! + +- Fixed a potential bug in ControlNet-LLLite training. PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) Thanks to aria1th! + +- Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) + +- Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. Also, Diffusers-based keys for LoRA weights are now supported. + +- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 + - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 + - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。 + - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。 + - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。 + - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。 + - 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。パラメータ数が多いほど効果が大きいため、SDXL の学習以外(LoRA 等)ではほぼ効果がなく(メモリ使用量のピークが他の場所にあるため)、それらの学習スクリプトへの実装予定もありません。 + +- SDXL の学習時に optimizer group 機能を追加しました。PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) + - Fused optimizer と同様の原理でメモリ使用量を削減します。学習結果や速度についても同様です。 + - `sdxl_train.py` に `--fused_optimizer_groups 10` のようにグループ数を指定してください。グループ数を増やすとメモリ使用量が削減されますが、速度は遅くなります。ある程度の数までしか効果がないため、4~10 程度を指定すると良いでしょう。 + - 任意の optimizer が使えますが、学習率を自動計算する optimizer (D-Adaptation や Prodigy など)は使えません。gradient accumulation は使えません。 + - `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。 + - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。 + +- LoRA+ がサポートされました。PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) rockerBOO 氏に感謝します。 + - LoRA の UP 側(LoRA-B)の学習率を上げることで学習速度の向上を図る手法です。倍数で指定します。元の論文では 16 が推奨されていますが、データセット等にもよりますので、適宜調整してください。PR もあわせてご覧ください。 + - `--network_args` で `loraplus_lr_ratio` を指定します。例:`--network_args "loraplus_lr_ratio=16"` + - `loraplus_unet_lr_ratio` と `loraplus_lr_ratio` で、U-Net および Text Encoder に個別の値を指定することも可能です。 + - 例:`--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` または `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` など + - `network_module` の `networks.lora` および `networks.dylora` で使用可能です。 + +- 画像の透明度(アルファチャネル)をロス計算時のマスクとして使用する機能が追加されました。PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) u-haru 氏に感謝します。 + - 透明部分が学習時に無視されるようになります。学習スクリプトに `--alpha_mask` オプションを指定するか、データセット設定ファイルに `alpha_mask = true` を指定してください。 + - 詳細は [マスクロスについて](./docs/masked_loss_README-ja.md) をご覧ください。 + +- SDXL の LoRA で階層別学習率、階層別 dim (rank) をサポートしました。PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331) + - ブロックごとに学習率および dim (rank) を指定することができます。 + - 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。 + +- `sdxl_train.py` での SDXL モデル学習時に負の学習率が指定できるようになりました。PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Cauldrath 氏に感謝します。 + - 学習画像から離れるように学習するため、モデルは容易に崩壊します。注意して使用してください。0 に近い値を推奨します。 + - コマンドラインから指定する場合、`--learning_rate=-1e-7` のように`=` を使ってください。 + +- 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。 + - API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。 + +- SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。 + +- `train_network.py` および `sdxl_train_network.py` で、学習再開時に DataSet の読み込み順についても復元できるようになりました。PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) KohakuBlueleaf 氏に感謝します。 + - これにより、学習再開時に DataSet の読み込み順が変わってしまう問題が解消されます。 + - `--skip_until_initial_step` オプションを指定すると、指定したステップまで DataSet 読み込みをスキップします。指定しない場合の動作は変わりません(DataSet の最初から読み込みます) + - `--resume` オプションを指定すると、state に保存されたステップ数が使用されます。 + - `--initial_step` または `--initial_epoch` オプションを指定すると、指定したステップまたはエポックまで DataSet 読み込みをスキップします。これらのオプションは `--skip_until_initial_step` と併用してください。またこれらのオプションは `--resume` と併用しなくても使えます(`--network_weights` を用いた学習再開時などにお使いください )。 + +- SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。 + - WSL 環境等でモデルファイルの読み込みが高速化されるようです。 + - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。 + +- ディスクにキャッシュされた latents ファイルに何らかのエラーがあったとき、そのファイル名が表示されるようになりました。 PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Cauldrath 氏に感謝します。 + +- `tag_images_by_wd14_tagger.py` で Onnx 未使用時に `--max_dataloader_n_workers` を指定するとエラーになる不具合が修正されました。 PR [#1291]( +https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290]( +https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します。 + +- データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1313) rockerBOO 氏に感謝します。 + +- ControlNet-LLLite 学習時の潜在バグが修正されました。 PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) aria1th 氏に感謝します。 -- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). -- `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details. -- `sdxl_merge_lora.py` also supports LBW. -- See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW. -- These will be included in the next release. +- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) -- `sdxl_merge_lora.py` が OFT をサポートされました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。 -- `svd_merge_lora.py` で LBW がサポートされました。PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) terracottahaniwa 氏に感謝します。 -- `sdxl_merge_lora.py` でも LBW がサポートされました。 -- LBW の詳細は hako-mikan 氏の [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) をご覧ください。 -- 以上は次回リリースに含まれます。 +- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。 ### Jun 23, 2024 / 2024-06-23: diff --git a/fine_tune.py b/fine_tune.py index c7e6bbd2e..d865cd2de 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -310,7 +310,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) # For --sample_at_first train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) @@ -354,7 +358,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) # Predict the noise residual with accelerator.autocast(): @@ -368,7 +374,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -380,7 +388,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + ) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -471,7 +481,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.end_training() - if is_main_process and (args.save_state or args.save_state_on_train_end): + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す diff --git a/gen_img.py b/gen_img.py index 4fe898716..d0a8f8141 100644 --- a/gen_img.py +++ b/gen_img.py @@ -1435,6 +1435,7 @@ class BatchDataBase(NamedTuple): clip_prompt: str guide_image: Any raw_prompt: str + file_name: Optional[str] class BatchDataExt(NamedTuple): @@ -2316,7 +2317,7 @@ def scale_and_round(x): # このバッチの情報を取り出す ( return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image, _), + (step_first, _, _, _, init_image, mask_image, _, guide_image, _, _), ( width, height, @@ -2339,6 +2340,7 @@ def scale_and_round(x): prompts = [] negative_prompts = [] raw_prompts = [] + filenames = [] start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) noises = [ torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) @@ -2371,7 +2373,7 @@ def scale_and_round(x): all_guide_images_are_same = True for i, ( _, - (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt, filename), _, ) in enumerate(batch): prompts.append(prompt) @@ -2379,6 +2381,7 @@ def scale_and_round(x): seeds.append(seed) clip_prompts.append(clip_prompt) raw_prompts.append(raw_prompt) + filenames.append(filename) if init_image is not None: init_images.append(init_image) @@ -2478,8 +2481,8 @@ def scale_and_round(x): # save image highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt, filename) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts, filenames) ): if highres_fix: seed -= 1 # record original seed @@ -2505,17 +2508,23 @@ def scale_and_round(x): metadata.add_text("crop-top", str(crop_top)) metadata.add_text("crop-left", str(crop_left)) - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" - else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + if filename is not None: + fln = filename else: - fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + if fln.endswith(".webp"): + image.save(os.path.join(args.outdir, fln), pnginfo=metadata, quality=100) # lossy + else: + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) if not args.no_preview and not highres_1st and args.interactive: try: @@ -2562,6 +2571,7 @@ def scale_and_round(x): # repeat prompt for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + filename = None if pi == 0 or len(raw_prompts) > 1: # parse prompt: if prompt is not changed, skip parsing @@ -2783,6 +2793,12 @@ def scale_and_round(x): logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") continue + m = re.match(r"f (.+)", parg, re.IGNORECASE) + if m: # filename + filename = m.group(1) + logger.info(f"filename: {filename}") + continue + except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(f"{ex}") @@ -2873,7 +2889,16 @@ def scale_and_round(x): b1 = BatchData( False, BatchDataBase( - global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + global_step, + prompt, + negative_prompt, + seed, + init_image, + mask_image, + clip_prompt, + guide_image, + raw_prompt, + filename, ), BatchDataExt( width, @@ -2916,7 +2941,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() add_logging_arguments(parser) - + parser.add_argument( "--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む" ) diff --git a/library/.DS_Store b/library/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..f5b92b35a1b6c5894af4d07b46c0548f8d08e035 GIT binary patch literal 6148 zcmeHKK~BR!4D^;r1ohG*$31bc_&}&qapKem1e#V9sZnx4f?GZT@8AU$C&Vjw0e^t8 zy`+iSBK3k$WlP@Kc*xDea%y6Wa-GGt@y@!AzYh~*z!)$F z4u%2LY?ipMXst0|3>X6q1N?vRP{t%+tth_^G;#$1Mlg$@FJ~p7X9F+^SS!K;aT*HL zP#3NkPQzjM!7m9|D{447SsC+$m0h@@{G8OB7k6@rqP50=G0 Optional[str]: def load_metadata_from_safetensors(model: str) -> dict: if not model.endswith(".safetensors"): return {} - + with safetensors.safe_open(model, framework="pt") as f: metadata = f.metadata() if metadata is None: diff --git a/library/sd3_models.py b/library/sd3_models.py new file mode 100644 index 000000000..a1ff1e75a --- /dev/null +++ b/library/sd3_models.py @@ -0,0 +1,2031 @@ +# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref +# the original code is licensed under the MIT License + +# and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! + +from ast import Tuple +from functools import partial +import math +from types import SimpleNamespace +from typing import Dict, List, Optional, Union +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +from transformers import CLIPTokenizer, T5TokenizerFast + + +memory_efficient_attention = None +try: + import xformers +except: + pass + +try: + from xformers.ops import memory_efficient_attention +except: + memory_efficient_attention = None + + +# region tokenizer +class SDTokenizer: + def __init__( + self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None + ): + """ + サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。 + Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings. + """ + self.tokenizer = tokenizer + self.max_length = max_length + self.min_length = min_length + empty = self.tokenizer("")["input_ids"] + if has_start_token: + self.tokens_start = 1 + self.start_token = empty[0] + self.end_token = empty[1] + else: + self.tokens_start = 0 + self.start_token = None + self.end_token = empty[0] + self.pad_with_end = pad_with_end + self.pad_to_max_length = pad_to_max_length + vocab = self.tokenizer.get_vocab() + self.inv_vocab = {v: k for k, v in vocab.items()} + self.max_word_length = 8 + + def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None): + """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. + The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" + """ + ja: テキストをトークン化し、重み値を持ちます - すべての値に1.0を仮定し、他の機能を無視します。 + 詳細は参考実装には関係なく、重み自体はSD3に対して弱い影響しかありません。へぇ~ + """ + if self.pad_with_end: + pad_token = self.end_token + else: + pad_token = 0 + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0)) + to_tokenize = text.replace("\n", " ").split(" ") + to_tokenize = [x for x in to_tokenize if x != ""] + for word in to_tokenize: + batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]]) + batch.append((self.end_token, 1.0)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) + if self.min_length is not None and len(batch) < self.min_length: + batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) + + # truncate to max_length + # print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}") + if truncate_to_max_length and len(batch) > self.max_length: + batch = batch[: self.max_length] + if truncate_length is not None and len(batch) > truncate_length: + batch = batch[:truncate_length] + + return [batch] + + +class T5XXLTokenizer(SDTokenizer): + """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + + def __init__(self): + super().__init__( + pad_with_end=False, + tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), + has_start_token=False, + pad_to_max_length=False, + max_length=99999999, + min_length=77, + ) + + +class SDXLClipGTokenizer(SDTokenizer): + def __init__(self, tokenizer): + super().__init__(pad_with_end=False, tokenizer=tokenizer) + + +class SD3Tokenizer: + def __init__(self, t5xxl=True): + # TODO cache tokenizer settings locally or hold them in the repo like ComfyUI + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) + self.clip_g = SDXLClipGTokenizer(clip_tokenizer) + self.t5xxl = T5XXLTokenizer() if t5xxl else None + # t5xxl has 99999999 max length, clip has 77 + self.model_max_length = self.clip_l.max_length # 77 + + def tokenize_with_weights(self, text: str): + # temporary truncate to max_length even for t5xxl + return ( + self.clip_l.tokenize_with_weights(text), + self.clip_g.tokenize_with_weights(text), + ( + self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length) + if self.t5xxl is not None + else None + ), + ) + + +# endregion + +# region mmdit + + +def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + scaling_factor=None, + offset=None, +): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + if scaling_factor is not None: + grid = grid / scaling_factor + if offset is not None: + grid = grid - offset + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid_torch( + embed_dim, + pos, + device=None, + dtype=torch.float32, +): + omega = torch.arange(embed_dim // 2, device=device, dtype=dtype) + omega *= 2.0 / embed_dim + omega = 1.0 / 10000**omega + out = torch.outer(pos.reshape(-1), omega) + emb = torch.cat([out.sin(), out.cos()], dim=1) + return emb + + +def get_2d_sincos_pos_embed_torch( + embed_dim, + w, + h, + val_center=7.5, + val_magnitude=7.5, + device=None, + dtype=torch.float32, +): + small = min(h, w) + val_h = (h / small) * val_magnitude + val_w = (w / small) * val_magnitude + grid_h, grid_w = torch.meshgrid( + torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), + torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), + indexing="ij", + ) + emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype) + emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype) + emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D) + return emb + + +def modulate(x, shift, scale): + if shift is None: + shift = torch.zeros_like(scale) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def default(x, default_value): + if x is None: + return default_value + return x + + +def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + # device=t.device, dtype=t.dtype + # ) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(dtype=t.dtype) + return embedding + + +def rmsnorm(x, eps=1e-6): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +class PatchEmbed(nn.Module): + def __init__( + self, + img_size=256, + patch_size=4, + in_channels=3, + embed_dim=512, + norm_layer=None, + flatten=True, + bias=True, + strict_img_size=True, + dynamic_img_pad=True, + ): + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + if img_size is not None: + self.img_size = img_size + self.grid_size = img_size // patch_size + self.num_patches = self.grid_size**2 + else: + self.img_size = None + self.grid_size = None + self.num_patches = None + + self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias) + self.norm = nn.Identity() if norm_layer is None else norm_layer(embed_dim) + + def forward(self, x): + B, C, H, W = x.shape + + if self.dynamic_img_pad: + # Pad input so we won't have partial patch + pad_h = (self.patch_size - H % self.patch_size) % self.patch_size + pad_w = (self.patch_size - W % self.patch_size) % self.patch_size + x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="reflect") + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +# FinalLayer in mmdit.py +class UnPatch(nn.Module): + def __init__(self, hidden_size=512, patch_size=4, out_channels=3): + super().__init__() + self.patch_size = patch_size + self.c = out_channels + + # eps is default in mmdit.py + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size), + ) + + def forward(self, x: torch.Tensor, cmod, H=None, W=None): + b, n, _ = x.shape + p = self.patch_size + c = self.c + if H is None and W is None: + w = h = int(n**0.5) + assert h * w == n + else: + h = H // p if H else n // (W // p) + w = W // p if W else n // h + assert h * w == n + + shift, scale = self.adaLN_modulation(cmod).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + + x = x.view(b, h, w, p, p, c) + x = x.permute(0, 5, 1, 3, 2, 4).contiguous() + x = x.view(b, c, h * p, w * p) + return x + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=lambda: nn.GELU(), + norm_layer=None, + bias=True, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.use_conv = use_conv + + layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = layer(in_features, hidden_features, bias=bias) + self.fc2 = layer(hidden_features, out_features, bias=bias) + self.act = act_layer() + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.fc2(x) + return x + + +class TimestepEmbedding(nn.Module): + def __init__(self, hidden_size, freq_embed_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(freq_embed_size, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + self.freq_embed_size = freq_embed_size + + def forward(self, t, dtype=None, **kwargs): + t_freq = timestep_embedding(t, self.freq_embed_size).to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class Embedder(nn.Module): + def __init__(self, input_dim, hidden_size): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + + def forward(self, x): + return self.mlp(x) + + +class RMSNorm(torch.nn.Module): + def __init__( + self, + dim: int, + elementwise_affine: bool = False, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + """ + super().__init__() + self.eps = eps + self.learnable_scale = elementwise_affine + if self.learnable_scale: + self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + else: + self.register_parameter("weight", None) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + """ + x = rmsnorm(x, eps=self.eps) + if self.learnable_scale: + return x * self.weight.to(device=x.device, dtype=x.dtype) + else: + return x + + +class SwiGLUFeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: float = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +# Linears for SelfAttention in mmdit.py +class AttentionLinears(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + pre_only: bool = False, + qk_norm: str = None, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + if not pre_only: + self.proj = nn.Linear(dim, dim) + self.pre_only = pre_only + + if qk_norm == "rms": + self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + elif qk_norm == "ln": + self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + elif qk_norm is None: + self.ln_q = nn.Identity() + self.ln_k = nn.Identity() + else: + raise ValueError(qk_norm) + + def pre_attention(self, x: torch.Tensor) -> torch.Tensor: + """ + output: + q, k, v: [B, L, D] + """ + B, L, C = x.shape + qkv: torch.Tensor = self.qkv(x) + q, k, v = qkv.reshape(B, L, -1, self.head_dim).chunk(3, dim=2) + q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1) + k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1) + return (q, k, v) + + def post_attention(self, x: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + x = self.proj(x) + return x + + +MEMORY_LAYOUTS = { + "torch": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2), + lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1), + lambda x: (1, x, 1, 1), + ), + "xformers": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim), + lambda x: x.reshape(x.shape[0], x.shape[1], -1), + lambda x: (1, 1, x, 1), + ), + "math": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2), + lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1), + lambda x: (1, x, 1, 1), + ), +} +# ATTN_FUNCTION = { +# "torch": F.scaled_dot_product_attention, +# "xformers": memory_efficient_attention, +# } + + +def vanilla_attention(q, k, v, mask, scale=None): + if scale is None: + scale = math.sqrt(q.size(-1)) + scores = torch.bmm(q, k.transpose(-1, -2)) / scale + if mask is not None: + mask = einops.rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(scores.dtype).max + mask = einops.repeat(mask, "b j -> (b h) j", h=q.size(-3)) + scores = scores.masked_fill(~mask, max_neg_value) + p_attn = F.softmax(scores, dim=-1) + return torch.bmm(p_attn, v) + + +def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"): + """ + q, k, v: [B, L, D] + """ + pre_attn_layout = MEMORY_LAYOUTS[mode][0] + post_attn_layout = MEMORY_LAYOUTS[mode][1] + q = pre_attn_layout(q, head_dim) + k = pre_attn_layout(k, head_dim) + v = pre_attn_layout(v, head_dim) + + # scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale) + if mode == "torch": + assert scale is None + scores = F.scaled_dot_product_attention(q, k.to(q), v.to(q), mask) # , scale=scale) + elif mode == "xformers": + scores = memory_efficient_attention(q, k.to(q), v.to(q), mask, scale=scale) + else: + scores = vanilla_attention(q, k.to(q), v.to(q), mask, scale=scale) + + scores = post_attn_layout(scores) + return scores + + +class SelfAttention(AttentionLinears): + def __init__(self, dim, num_heads=8, mode="xformers"): + super().__init__(dim, num_heads, qkv_bias=True, pre_only=False) + assert mode in MEMORY_LAYOUTS + self.head_dim = dim // num_heads + self.attn_mode = mode + + def set_attn_mode(self, mode): + self.attn_mode = mode + + def forward(self, x): + q, k, v = self.pre_attention(x) + attn_score = attention(q, k, v, self.head_dim, mode=self.attn_mode) + return self.post_attention(attn_score) + + +class TransformerBlock(nn.Module): + def __init__(self, context_size, mode="xformers"): + super().__init__() + self.context_size = context_size + self.norm1 = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) + self.attn = SelfAttention(context_size, mode=mode) + self.norm2 = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) + self.mlp = MLP( + in_features=context_size, + hidden_features=context_size * 4, + act_layer=lambda: nn.GELU(approximate="tanh"), + ) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, context_size, num_layers, mode="xformers"): + super().__init__() + self.layers = nn.ModuleList([TransformerBlock(context_size, mode) for _ in range(num_layers)]) + self.norm = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return self.norm(x) + + +# DismantledBlock in mmdit.py +class SingleDiTBlock(nn.Module): + """ + A DiT block with gated adaptive layer norm (adaLN) conditioning. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: str = "xformers", + qkv_bias: bool = False, + pre_only: bool = False, + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + qk_norm: Optional[str] = None, + **block_kwargs, + ): + super().__init__() + assert attn_mode in MEMORY_LAYOUTS + self.attn_mode = attn_mode + if not rmsnorm: + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + else: + self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = AttentionLinears( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + pre_only=pre_only, + qk_norm=qk_norm, + ) + if not pre_only: + if not rmsnorm: + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + else: + self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + if not pre_only: + if not swiglu: + self.mlp = MLP( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=lambda: nn.GELU(approximate="tanh"), + ) + else: + self.mlp = SwiGLUFeedForward( + dim=hidden_size, + hidden_dim=mlp_hidden_dim, + multiple_of=256, + ) + self.scale_mod_only = scale_mod_only + if not scale_mod_only: + n_mods = 6 if not pre_only else 2 + else: + n_mods = 4 if not pre_only else 1 + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size)) + self.pre_only = pre_only + + def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + if not self.pre_only: + if not self.scale_mod_only: + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation( + c + ).chunk(6, dim=-1) + else: + shift_msa = None + shift_mlp = None + ( + scale_msa, + gate_msa, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation( + c + ).chunk(4, dim=-1) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, ( + x, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) + else: + if not self.scale_mod_only: + ( + shift_msa, + scale_msa, + ) = self.adaLN_modulation( + c + ).chunk(2, dim=-1) + else: + shift_msa = None + scale_msa = self.adaLN_modulation(c) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, None + + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): + assert not self.pre_only + x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +# JointBlock + block_mixing in mmdit.py +class MMDiTBlock(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + pre_only = kwargs.pop("pre_only") + self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs) + self.x_block = SingleDiTBlock(*args, pre_only=False, **kwargs) + self.head_dim = self.x_block.attn.head_dim + self.mode = self.x_block.attn_mode + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def _forward(self, context, x, c): + ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c) + x_qkv, x_intermediate = self.x_block.pre_attention(x, c) + + ctx_len = ctx_qkv[0].size(1) + + q = torch.concat((ctx_qkv[0], x_qkv[0]), dim=1) + k = torch.concat((ctx_qkv[1], x_qkv[1]), dim=1) + v = torch.concat((ctx_qkv[2], x_qkv[2]), dim=1) + + attn = attention(q, k, v, head_dim=self.head_dim, mode=self.mode) + ctx_attn_out = attn[:, :ctx_len] + x_attn_out = attn[:, ctx_len:] + + x = self.x_block.post_attention(x_attn_out, *x_intermediate) + if not self.context_block.pre_only: + context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate) + else: + context = None + return context, x + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + +class MMDiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + input_size: int = 32, + patch_size: int = 2, + in_channels: int = 4, + depth: int = 28, + # hidden_size: Optional[int] = None, + # num_heads: Optional[int] = None, + mlp_ratio: float = 4.0, + learn_sigma: bool = False, + adm_in_channels: Optional[int] = None, + context_embedder_config: Optional[Dict] = None, + use_checkpoint: bool = False, + register_length: int = 0, + attn_mode: str = "torch", + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + out_channels: Optional[int] = None, + pos_embed_scaling_factor: Optional[float] = None, + pos_embed_offset: Optional[float] = None, + pos_embed_max_size: Optional[int] = None, + num_patches=None, + qk_norm: Optional[str] = None, + qkv_bias: bool = True, + context_processor_layers=None, + context_size=4096, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + default_out_channels = in_channels * 2 if learn_sigma else in_channels + self.out_channels = default(out_channels, default_out_channels) + self.patch_size = patch_size + self.pos_embed_scaling_factor = pos_embed_scaling_factor + self.pos_embed_offset = pos_embed_offset + self.pos_embed_max_size = pos_embed_max_size + self.gradient_checkpointing = use_checkpoint + + # hidden_size = default(hidden_size, 64 * depth) + # num_heads = default(num_heads, hidden_size // 64) + + # apply magic --> this defines a head_size of 64 + self.hidden_size = 64 * depth + num_heads = depth + + self.num_heads = num_heads + + self.x_embedder = PatchEmbed( + input_size, + patch_size, + in_channels, + self.hidden_size, + bias=True, + strict_img_size=self.pos_embed_max_size is None, + ) + self.t_embedder = TimestepEmbedding(self.hidden_size) + + self.y_embedder = None + if adm_in_channels is not None: + assert isinstance(adm_in_channels, int) + self.y_embedder = Embedder(adm_in_channels, self.hidden_size) + + if context_processor_layers is not None: + self.context_processor = Transformer(context_size, context_processor_layers, attn_mode) + else: + self.context_processor = None + + self.context_embedder = nn.Linear(context_size, self.hidden_size) + self.register_length = register_length + if self.register_length > 0: + self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size)) + + # num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + # just use a buffer already + if num_patches is not None: + self.register_buffer( + "pos_embed", + torch.empty(1, num_patches, self.hidden_size), + ) + else: + self.pos_embed = None + + self.use_checkpoint = use_checkpoint + self.joint_blocks = nn.ModuleList( + [ + MMDiTBlock( + self.hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + qkv_bias=qkv_bias, + pre_only=i == depth - 1, + rmsnorm=rmsnorm, + scale_mod_only=scale_mod_only, + swiglu=swiglu, + qk_norm=qk_norm, + ) + for i in range(depth) + ] + ) + for block in self.joint_blocks: + block.gradient_checkpointing = use_checkpoint + + self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels) + # self.initialize_weights() + + @property + def model_type(self): + return "m" # only support medium + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + for block in self.joint_blocks: + block.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + for block in self.joint_blocks: + block.disable_gradient_checkpointing() + + def initialize_weights(self): + # TODO: Init context_embedder? + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding + if self.pos_embed is not None: + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.pos_embed.shape[-2] ** 0.5), + scaling_factor=self.pos_embed_scaling_factor, + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + if getattr(self, "y_embedder", None) is not None: + nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.joint_blocks: + nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def cropped_pos_embed(self, h, w, device=None): + p = self.x_embedder.patch_size + # patched size + h = (h + 1) // p + w = (w + 1) // p + if self.pos_embed is None: + return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device) + assert self.pos_embed_max_size is not None + assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) + assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + spatial_pos_embed = self.pos_embed.reshape( + 1, + self.pos_embed_max_size, + self.pos_embed_max_size, + self.pos_embed.shape[-1], + ) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of DiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, D) tensor of class labels + """ + + if self.context_processor is not None: + context = self.context_processor(context) + + B, C, H, W = x.shape + x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device).to(dtype=x.dtype) + c = self.t_embedder(t, dtype=x.dtype) # (N, D) + if y is not None and self.y_embedder is not None: + y = self.y_embedder(y) # (N, D) + c = c + y # (N, D) + + if context is not None: + context = self.context_embedder(context) + + if self.register_length > 0: + context = torch.cat( + ( + einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), + default(context, torch.Tensor([]).type_as(x)), + ), + 1, + ) + + for block in self.joint_blocks: + context, x = block(context, x, c) + x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify + return x[:, :, :H, :W] + + +def create_mmdit_sd3_medium_configs(attn_mode: str): + # {'patch_size': 2, 'depth': 24, 'num_patches': 36864, + # 'pos_embed_max_size': 192, 'adm_in_channels': 2048, 'context_embedder': + # {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}} + mmdit = MMDiT( + input_size=None, + pos_embed_max_size=192, + patch_size=2, + in_channels=16, + adm_in_channels=2048, + depth=24, + mlp_ratio=4, + qk_norm=None, + num_patches=36864, + context_size=4096, + attn_mode=attn_mode, + ) + return mmdit + + +# endregion + +# region VAE +# TODO support xformers + +VAE_SCALE_FACTOR = 1.5305 +VAE_SHIFT_FACTOR = 0.0609 + + +def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) + + +class ResnetBlock(torch.nn.Module): + def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = Normalize(in_channels, dtype=dtype, device=device) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.norm2 = Normalize(out_channels, dtype=dtype, device=device) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + if self.in_channels != self.out_channels: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device + ) + else: + self.nin_shortcut = None + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + hidden = x + hidden = self.norm1(hidden) + hidden = self.swish(hidden) + hidden = self.conv1(hidden) + hidden = self.norm2(hidden) + hidden = self.swish(hidden) + hidden = self.conv2(hidden) + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + hidden + + +class AttnBlock(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.norm = Normalize(in_channels, dtype=dtype, device=device) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + + def forward(self, x): + hidden = self.norm(x) + q = self.q(hidden) + k = self.k(hidden) + v = self.v(hidden) + b, c, h, w = q.shape + q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)) + hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default + hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + hidden = self.proj_out(hidden) + return x + hidden + + +class Downsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + + def forward(self, x): + org_dtype = x.dtype + if x.dtype == torch.bfloat16: + x = x.to(torch.float32) + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if x.dtype != org_dtype: + x = x.to(org_dtype) + x = self.conv(x) + return x + + +class VAEEncoder(torch.nn.Module): + def __init__( + self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = torch.nn.ModuleList() + for i_level in range(self.num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, dtype=dtype, device=device) + self.down.append(down) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = self.swish(h) + h = self.conv_out(h) + return h + + +class VAEDecoder(torch.nn.Module): + def __init__( + self, + ch=128, + out_ch=3, + ch_mult=(1, 2, 4, 4), + num_res_blocks=2, + resolution=256, + z_channels=16, + dtype=torch.float32, + device=None, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # upsampling + self.up = torch.nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = torch.nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + up = torch.nn.Module() + up.block = block + if i_level != 0: + up.upsample = Upsample(block_in, dtype=dtype, device=device) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, z): + # z to block_in + hidden = self.conv_in(z) + # middle + hidden = self.mid.block_1(hidden) + hidden = self.mid.attn_1(hidden) + hidden = self.mid.block_2(hidden) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden = self.up[i_level].block[i_block](hidden) + if i_level != 0: + hidden = self.up[i_level].upsample(hidden) + # end + hidden = self.norm_out(hidden) + hidden = self.swish(hidden) + hidden = self.conv_out(hidden) + return hidden + + +class SDVAE(torch.nn.Module): + def __init__(self, dtype=torch.float32, device=None): + super().__init__() + self.encoder = VAEEncoder(dtype=dtype, device=device) + self.decoder = VAEDecoder(dtype=dtype, device=device) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + # @torch.autocast("cuda", dtype=torch.float16) + def decode(self, latent): + return self.decoder(latent) + + # @torch.autocast("cuda", dtype=torch.float16) + def encode(self, image): + hidden = self.encoder(image) + mean, logvar = torch.chunk(hidden, 2, dim=1) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + + @staticmethod + def process_in(latent): + return (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR + + @staticmethod + def process_out(latent): + return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR + + +class VAEOutput: + def __init__(self, latent): + self.latent = latent + + @property + def latent_dist(self): + return self + + def sample(self): + return self.latent + + +class VAEWrapper: + def __init__(self, vae): + self.vae = vae + + @property + def device(self): + return self.vae.device + + @property + def dtype(self): + return self.vae.dtype + + # latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + def encode(self, image): + return VAEOutput(self.vae.encode(image)) + + +# endregion + + +# region Text Encoder +class CLIPAttention(torch.nn.Module): + def __init__(self, embed_dim, heads, dtype, device, mode="xformers"): + super().__init__() + self.heads = heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.attn_mode = mode + + def set_attn_mode(self, mode): + self.attn_mode = mode + + def forward(self, x, mask=None): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + out = attention(q, k, v, self.heads, mask, mode=self.attn_mode) + return self.out_proj(out) + + +ACTIVATIONS = { + "quick_gelu": lambda: (lambda a: a * torch.sigmoid(1.702 * a)), + # "gelu": torch.nn.functional.gelu, + "gelu": lambda: nn.GELU(), +} + + +class CLIPLayer(torch.nn.Module): + def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) + self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + # # self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) + # self.mlp = Mlp( + # embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device + # ) + self.mlp = MLP(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation]) + self.mlp.to(device=device, dtype=dtype) + + def forward(self, x, mask=None): + x += self.self_attn(self.layer_norm1(x), mask) + x += self.mlp(self.layer_norm2(x)) + return x + + +class CLIPEncoder(torch.nn.Module): + def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layers = torch.nn.ModuleList( + [CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)] + ) + + def forward(self, x, mask=None, intermediate_output=None): + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layers) + intermediate_output + intermediate = None + for i, l in enumerate(self.layers): + x = l(x, mask) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + + +class CLIPEmbeddings(torch.nn.Module): + def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): + super().__init__() + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens): + return self.token_embedding(input_tokens) + self.position_embedding.weight + + +class CLIPTextModel_(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + super().__init__() + self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) + self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) + self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True): + x = self.embeddings(input_tokens) + + if x.dtype == torch.bfloat16: + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=torch.float32, device=x.device).fill_(float("-inf")).triu_(1) + causal_mask = causal_mask.to(dtype=x.dtype) + else: + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + + x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output) + x = self.final_layer_norm(x) + if i is not None and final_layer_norm_intermediate: + i = self.final_layer_norm(i) + pooled_output = x[ + torch.arange(x.shape[0], device=x.device), + input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1), + ] + return x, i, pooled_output + + +class CLIPTextModel(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_hidden_layers"] + self.text_model = CLIPTextModel_(config_dict, dtype, device) + embed_dim = config_dict["hidden_size"] + self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + self.text_projection.weight.copy_(torch.eye(embed_dim)) + self.dtype = dtype + + def get_input_embeddings(self): + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, embeddings): + self.text_model.embeddings.token_embedding = embeddings + + def forward(self, *args, **kwargs): + x = self.text_model(*args, **kwargs) + out = self.text_projection(x[2]) + return (x[0], x[1], out, x[2]) + + +class ClipTokenWeightEncoder: + # def encode_token_weights(self, token_weight_pairs): + # tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + # out, pooled = self([tokens]) + # if pooled is not None: + # first_pooled = pooled[0:1] + # else: + # first_pooled = pooled + # output = [out[0:1]] + # return torch.cat(output, dim=-2), first_pooled + + # fix to support batched inputs + # : Union[List[Tuple[torch.Tensor, torch.Tensor]], List[List[Tuple[torch.Tensor, torch.Tensor]]]] + def encode_token_weights(self, list_of_token_weight_pairs): + has_batch = isinstance(list_of_token_weight_pairs[0][0], list) + + if has_batch: + list_of_tokens = [] + for pairs in list_of_token_weight_pairs: + tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] + list_of_tokens.append(tokens) + else: + list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] + + out, pooled = self(list_of_tokens) + if has_batch: + return out, pooled + else: + if pooled is not None: + first_pooled = pooled[0:1] + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2), first_pooled + + +class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + device="cpu", + max_length=77, + layer="last", + layer_idx=None, + textmodel_json_config=None, + dtype=None, + model_class=CLIPTextModel, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, + layer_norm_hidden_state=True, + return_projected_pooled=True, + ): + super().__init__() + assert layer in self.LAYERS + self.transformer = model_class(textmodel_json_config, dtype, device) + self.num_layers = self.transformer.num_layers + self.max_length = max_length + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + self.layer = layer + self.layer_idx = None + self.special_tokens = special_tokens + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.layer_norm_hidden_state = layer_norm_hidden_state + self.return_projected_pooled = return_projected_pooled + if layer == "hidden": + assert layer_idx is not None + assert abs(layer_idx) < self.num_layers + self.set_clip_options({"layer": layer_idx}) + self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) + + def set_attn_mode(self, mode): + raise NotImplementedError("This model does not support setting the attention mode") + + def set_clip_options(self, options): + layer_idx = options.get("layer", self.layer_idx) + self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + if layer_idx is None or abs(layer_idx) > self.num_layers: + self.layer = "last" + else: + self.layer = "hidden" + self.layer_idx = layer_idx + + def forward(self, tokens): + backup_embeds = self.transformer.get_input_embeddings() + device = backup_embeds.weight.device + tokens = torch.LongTensor(tokens).to(device) + outputs = self.transformer( + tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state + ) + self.transformer.set_input_embeddings(backup_embeds) + if self.layer == "last": + z = outputs[0] + else: + z = outputs[1] + pooled_output = None + if len(outputs) >= 3: + if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: + pooled_output = outputs[3].float() + elif outputs[2] is not None: + pooled_output = outputs[2].float() + return z.float(), pooled_output + + def set_attn_mode(self, mode): + clip_text_model = self.transformer.text_model + for layer in clip_text_model.encoder.layers: + layer.self_attn.set_attn_mode(mode) + + +class SDXLClipG(SDClipModel): + """Wraps the CLIP-G model into the SD-CLIP-Model interface""" + + def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None): + if layer == "penultimate": + layer = "hidden" + layer_idx = -2 + super().__init__( + device=device, + layer=layer, + layer_idx=layer_idx, + textmodel_json_config=config, + dtype=dtype, + special_tokens={"start": 49406, "end": 49407, "pad": 0}, + layer_norm_hidden_state=False, + ) + + def set_attn_mode(self, mode): + clip_text_model = self.transformer.text_model + for layer in clip_text_model.encoder.layers: + layer.self_attn.set_attn_mode(mode) + + +class T5XXLModel(SDClipModel): + """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience""" + + def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None): + super().__init__( + device=device, + layer=layer, + layer_idx=layer_idx, + textmodel_json_config=config, + dtype=dtype, + special_tokens={"end": 1, "pad": 0}, + model_class=T5, + ) + + def set_attn_mode(self, mode): + t5: T5 = self.transformer + for t5block in t5.encoder.block: + t5block: T5Block + t5layer: T5LayerSelfAttention = t5block.layer[0] + t5SaSa: T5Attention = t5layer.SelfAttention + t5SaSa.set_attn_mode(mode) + + +################################################################################################# +### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl +################################################################################################# + + +class T5XXLTokenizer(SDTokenizer): + """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + + def __init__(self): + super().__init__( + pad_with_end=False, + tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), + has_start_token=False, + pad_to_max_length=False, + max_length=99999999, + min_length=77, + ) + + +class T5LayerNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) + self.variance_epsilon = eps + + # def forward(self, x): + # variance = x.pow(2).mean(-1, keepdim=True) + # x = x * torch.rsqrt(variance + self.variance_epsilon) + # return self.weight.to(device=x.device, dtype=x.dtype) * x + + # copy from transformers' T5LayerNorm + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class T5DenseGatedActDense(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") + hidden_linear = self.wi_1(x) + x = hidden_gelu * hidden_linear + x = self.wo(x) + return x + + +class T5LayerFF(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x): + forwarded_states = self.layer_norm(x) + forwarded_states = self.DenseReluDense(forwarded_states) + x += forwarded_states + return x + + +class T5Attention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device) + self.num_heads = num_heads + self.relative_attention_bias = None + if relative_attention_bias: + self.relative_attention_num_buckets = 32 + self.relative_attention_max_distance = 128 + self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device) + + self.attn_mode = "xformers" # TODO 何とかする + + def set_attn_mode(self, mode): + self.attn_mode = mode + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=True, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward(self, x, past_bias=None): + q = self.q(x) + k = self.k(x) + v = self.v(x) + if self.relative_attention_bias is not None: + past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device) + if past_bias is not None: + mask = past_bias + out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask, mode=self.attn_mode) + return self.o(out), past_bias + + +class T5LayerSelfAttention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x, past_bias=None): + output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias) + x += output + return x, past_bias + + +class T5Block(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.layer = torch.nn.ModuleList() + self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device)) + self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device)) + + def forward(self, x, past_bias=None): + x, past_bias = self.layer[0](x, past_bias) + + # copy from transformers' T5Block + # clamp inf values to enable fp16 training + if x.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(x).any(), + torch.finfo(x.dtype).max - 1000, + torch.finfo(x.dtype).max, + ) + x = torch.clamp(x, min=-clamp_value, max=clamp_value) + + x = self.layer[-1](x) + # clamp inf values to enable fp16 training + if x.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(x).any(), + torch.finfo(x.dtype).max - 1000, + torch.finfo(x.dtype).max, + ) + x = torch.clamp(x, min=-clamp_value, max=clamp_value) + + return x, past_bias + + +class T5Stack(torch.nn.Module): + def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device): + super().__init__() + self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device) + self.block = torch.nn.ModuleList( + [ + T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) + for i in range(num_layers) + ] + ) + self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True): + intermediate = None + x = self.embed_tokens(input_ids) + past_bias = None + for i, l in enumerate(self.block): + # uncomment to debug layerwise output: fp16 may cause issues + # print(i, x.mean(), x.std()) + x, past_bias = l(x, past_bias) + if i == intermediate_output: + intermediate = x.clone() + # print(x.mean(), x.std()) + x = self.final_layer_norm(x) + if intermediate is not None and final_layer_norm_intermediate: + intermediate = self.final_layer_norm(intermediate) + # print(x.mean(), x.std()) + return x, intermediate + + +class T5(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_layers"] + self.encoder = T5Stack( + self.num_layers, + config_dict["d_model"], + config_dict["d_model"], + config_dict["d_ff"], + config_dict["num_heads"], + config_dict["vocab_size"], + dtype, + device, + ) + self.dtype = dtype + + def get_input_embeddings(self): + return self.encoder.embed_tokens + + def set_input_embeddings(self, embeddings): + self.encoder.embed_tokens = embeddings + + def forward(self, *args, **kwargs): + return self.encoder(*args, **kwargs) + + +def create_clip_l(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None): + r""" + state_dict is not loaded, but updated with missing keys + """ + CLIPL_CONFIG = { + "hidden_act": "quick_gelu", + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, + } + with torch.no_grad(): + clip_l = SDClipModel( + layer="hidden", + layer_idx=-2, + device=device, + dtype=dtype, + layer_norm_hidden_state=False, + return_projected_pooled=False, + textmodel_json_config=CLIPL_CONFIG, + ) + if state_dict is not None: + # update state_dict if provided to include logit_scale and text_projection.weight avoid errors + if "logit_scale" not in state_dict: + state_dict["logit_scale"] = clip_l.logit_scale + if "transformer.text_projection.weight" not in state_dict: + state_dict["transformer.text_projection.weight"] = clip_l.transformer.text_projection.weight + return clip_l + + +def create_clip_g(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None): + r""" + state_dict is not loaded, but updated with missing keys + """ + CLIPG_CONFIG = { + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, + } + with torch.no_grad(): + clip_g = SDXLClipG(CLIPG_CONFIG, device=device, dtype=dtype) + if state_dict is not None: + if "logit_scale" not in state_dict: + state_dict["logit_scale"] = clip_g.logit_scale + return clip_g + + +def create_t5xxl(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> T5XXLModel: + T5_CONFIG = {"d_ff": 10240, "d_model": 4096, "num_heads": 64, "num_layers": 24, "vocab_size": 32128} + with torch.no_grad(): + t5 = T5XXLModel(T5_CONFIG, dtype=dtype, device=device) + if state_dict is not None: + if "logit_scale" not in state_dict: + state_dict["logit_scale"] = t5.logit_scale + if "transformer.shared.weight" in state_dict: + state_dict.pop("transformer.shared.weight") + return t5 + + +""" + # snippet for using the T5 model from transformers + + from transformers import T5EncoderModel, T5Config + import accelerate + import json + + T5_CONFIG_JSON = "" +{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.41.2", + "use_cache": true, + "vocab_size": 32128 +} +"" + config = json.loads(T5_CONFIG_JSON) + config = T5Config(**config) + + # model = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3") + # print(model.config) + # # model(**load_model.config) + + # with accelerate.init_empty_weights(): + model = T5EncoderModel._from_config(config) # , torch_dtype=dtype) + for key in list(state_dict.keys()): + if key.startswith("transformer."): + new_key = key[len("transformer.") :] + state_dict[new_key] = state_dict.pop(key) + + info = model.load_state_dict(state_dict) + print(info) + model.set_attn_mode = lambda x: None + # model.to("cpu") + + _self = model + + def enc(list_of_token_weight_pairs): + has_batch = isinstance(list_of_token_weight_pairs[0][0], list) + + if has_batch: + list_of_tokens = [] + for pairs in list_of_token_weight_pairs: + tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] + list_of_tokens.append(tokens) + else: + list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] + + list_of_tokens = np.array(list_of_tokens) + list_of_tokens = torch.from_numpy(list_of_tokens).to("cuda", dtype=torch.long) + out = _self(list_of_tokens) + pooled = None + if has_batch: + return out, pooled + else: + if pooled is not None: + first_pooled = pooled[0:1] + else: + first_pooled = pooled + return out[0], first_pooled + # output = [out[0:1]] + # return torch.cat(output, dim=-2), first_pooled + + model.encode_token_weights = enc + + return model +""" + +# endregion diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py new file mode 100644 index 000000000..245912199 --- /dev/null +++ b/library/sd3_train_utils.py @@ -0,0 +1,656 @@ +import argparse +import glob +import math +import os +from typing import List, Optional, Tuple, Union + +import torch +from safetensors.torch import save_file +from accelerate import Accelerator + +from library import sd3_models, sd3_utils, train_util +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +# from transformers import CLIPTokenizer +# from library import model_util +# , sdxl_model_util, train_util, sdxl_original_unet +# from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from .sdxl_train_util import match_mixed_precision + + +def load_target_model( + model_type: str, + args: argparse.Namespace, + state_dict: dict, + accelerator: Accelerator, + attn_mode: str, + model_dtype: Optional[torch.dtype], + device: Optional[torch.device], +) -> Union[ + sd3_models.MMDiT, + Optional[sd3_models.SDClipModel], + Optional[sd3_models.SDXLClipG], + Optional[sd3_models.T5XXLModel], + sd3_models.SDVAE, +]: + loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu") + + for pi in range(accelerator.state.num_processes): + if pi == accelerator.state.local_process_index: + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + + if model_type == "mmdit": + model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device) + elif model_type == "clip_l": + model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device) + elif model_type == "clip_g": + model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device) + elif model_type == "t5xxl": + model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device) + elif model_type == "vae": + model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device) + else: + raise ValueError(f"Unknown model type: {model_type}") + + # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device + if args.lowram: + model = model.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + accelerator.wait_for_everyone() + + return model + + +def save_models( + ckpt_path: str, + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel], + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, +): + r""" + Save models to checkpoint file. Only supports unified checkpoint format. + """ + + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("model.diffusion_model.", mmdit.state_dict()) + update_sd("first_stage_model.", vae.state_dict()) + + if clip_l is not None: + update_sd("text_encoders.clip_l.", clip_l.state_dict()) + if clip_g is not None: + update_sd("text_encoders.clip_g.", clip_g.state_dict()) + if t5xxl is not None: + update_sd("text_encoders.t5xxl.", t5xxl.state_dict()) + + save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_sd3_model_on_train_end( + args: argparse.Namespace, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel], + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type + ) + save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_sd3_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel], + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type + ) + save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +def add_sd3_training_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + parser.add_argument( + "--text_encoder_batch_size", + type=int, + default=None, + help="text encoder batch size (default: None, use dataset's batch size)" + + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", + ) + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) + + parser.add_argument( + "--clip_l", + type=str, + required=False, + help="CLIP-L model path. if not specified, use ckpt's state_dict / CLIP-Lモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--clip_g", + type=str, + required=False, + help="CLIP-G model path. if not specified, use ckpt's state_dict / CLIP-Gモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--t5xxl", + type=str, + required=False, + help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--save_clip", action="store_true", help="save CLIP models to checkpoint / CLIPモデルをチェックポイントに保存する" + ) + parser.add_argument( + "--save_t5xxl", action="store_true", help="save T5-XXL model to checkpoint / T5-XXLモデルをチェックポイントに保存する" + ) + + parser.add_argument( + "--t5xxl_device", + type=str, + default=None, + help="T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用", + ) + parser.add_argument( + "--t5xxl_dtype", + type=str, + default=None, + help="T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", + ) + + # copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="logit_normal", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + + +def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): + assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" + if args.v_parameterization: + logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") + + if args.clip_skip is not None: + logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + + # if args.multires_noise_iterations: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" + # ) + # else: + # if args.noise_offset is None: + # args.noise_offset = DEFAULT_NOISE_OFFSET + # elif args.noise_offset != DEFAULT_NOISE_OFFSET: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" + # ) + # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + + assert ( + not hasattr(args, "weighted_captions") or not args.weighted_captions + ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + + if supportTextEncoderCaching: + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + args.cache_text_encoder_outputs = True + logger.warning( + "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" + ) + + +def sample_images(*args, **kwargs): + return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) + + +class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): + SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.vae = None + + def set_vae(self, vae: sd3_models.SDVAE): + self.vae = vae + + def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8) # bucket_reso is (W, H) + + try: + npz = np.load(npz_path) + if npz["latents"].shape[1:3] != expected_latents_size: + return False + + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): + return False + else: + if "alpha_mask" in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): + img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( + image_infos, alpha_mask, random_crop + ) + img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype) + + with torch.no_grad(): + latents_tensors = self.vae.encode(img_tensor).to("cpu") + if flip_aug: + img_tensor = torch.flip(img_tensor, dims=[3]) + with torch.no_grad(): + flipped_latents = self.vae.encode(img_tensor).to("cpu") + else: + flipped_latents = [None] * len(latents_tensors) + + # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): + for i in range(len(image_infos)): + info = image_infos[i] + latents = latents_tensors[i] + flipped_latent = flipped_latents[i] + alpha_mask = alpha_masks[i] + original_size = original_sizes[i] + crop_ltrb = crop_ltrbs[i] + + if self.cache_to_disk: + kwargs = {} + if flipped_latent is not None: + kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() + np.savez( + info.latents_npz, + latents=latents.float().cpu().numpy(), + original_size=np.array(original_size), + crop_ltrb=np.array(crop_ltrb), + **kwargs, + ) + else: + info.latents = latents + if flip_aug: + info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask + + if not train_util.HIGH_VRAM: + clean_memory_on_device(self.vae.device) + + +# region Diffusers + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils.torch_utils import randn_tensor +from diffusers.utils import BaseOutput + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + ): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps) + + sigmas = timesteps / self.config.num_train_timesteps + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + + timesteps = sigmas * self.config.num_train_timesteps + self.timesteps = timesteps.to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + + # if self.config.prediction_type == "vector_field": + + denoised = sample - model_output * sigma + # 2. Convert to an ODE derivative + derivative = (sample - denoised) / sigma_hat + + dt = self.sigmas[self.step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps + + +# endregion diff --git a/library/sd3_utils.py b/library/sd3_utils.py new file mode 100644 index 000000000..16f80c60d --- /dev/null +++ b/library/sd3_utils.py @@ -0,0 +1,513 @@ +import math +from typing import Dict, Optional, Union +import torch +import safetensors +from safetensors.torch import load_file +from accelerate import init_empty_weights + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sd3_models + +# TODO move some of functions to model_util.py +from library import sdxl_model_util + +# region models + + +def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False): + if disable_mmap: + return safetensors.torch.load(open(path, "rb").read()) + else: + try: + return load_file(path, device=dvc) + except: + return load_file(path) # prevent device invalid Error + + +def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]): + mmdit_sd = {} + + mmdit_prefix = "model.diffusion_model." + for k in list(state_dict.keys()): + if k.startswith(mmdit_prefix): + mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k) + + # load MMDiT + logger.info("Building MMDit") + with init_empty_weights(): + mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) + + logger.info("Loading state dict...") + info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype) + logger.info(f"Loaded MMDiT: {info}") + return mmdit + + +def load_clip_l( + state_dict: Dict, + clip_l_path: Optional[str], + attn_mode: str, + clip_dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, +): + clip_l_sd = None + if clip_l_path: + logger.info(f"Loading clip_l from {clip_l_path}...") + clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap) + for key in list(clip_l_sd.keys()): + clip_l_sd["transformer." + key] = clip_l_sd.pop(key) + else: + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + + if clip_l_sd is None: + clip_l = None + else: + logger.info("Building ClipL") + clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) + logger.info("Loading state dict...") + info = clip_l.load_state_dict(clip_l_sd) + logger.info(f"Loaded ClipL: {info}") + clip_l.set_attn_mode(attn_mode) + return clip_l + + +def load_clip_g( + state_dict: Dict, + clip_g_path: Optional[str], + attn_mode: str, + clip_dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, +): + clip_g_sd = None + if clip_g_path: + logger.info(f"Loading clip_g from {clip_g_path}...") + clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap) + for key in list(clip_g_sd.keys()): + clip_g_sd["transformer." + key] = clip_g_sd.pop(key) + else: + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + + if clip_g_sd is None: + clip_g = None + else: + logger.info("Building ClipG") + clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) + logger.info("Loading state dict...") + info = clip_g.load_state_dict(clip_g_sd) + logger.info(f"Loaded ClipG: {info}") + clip_g.set_attn_mode(attn_mode) + return clip_g + + +def load_t5xxl( + state_dict: Dict, + t5xxl_path: Optional[str], + attn_mode: str, + dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, +): + t5xxl_sd = None + if t5xxl_path: + logger.info(f"Loading t5xxl from {t5xxl_path}...") + t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap) + for key in list(t5xxl_sd.keys()): + t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) + else: + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k in list(state_dict.keys()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + + if t5xxl_sd is None: + t5xxl = None + else: + logger.info("Building T5XXL") + + # workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device + t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd) + t5xxl.to(dtype=dtype) + + logger.info("Loading state dict...") + info = t5xxl.load_state_dict(t5xxl_sd) + logger.info(f"Loaded T5XXL: {info}") + t5xxl.set_attn_mode(attn_mode) + return t5xxl + + +def load_vae( + state_dict: Dict, + vae_path: Optional[str], + vae_dtype: Optional[Union[str, torch.dtype]], + device: Optional[Union[str, torch.device]], + disable_mmap: bool = False, +): + vae_sd = {} + if vae_path: + logger.info(f"Loading VAE from {vae_path}...") + vae_sd = load_safetensors(vae_path, device, disable_mmap) + else: + # remove prefix "first_stage_model." + vae_sd = {} + vae_prefix = "first_stage_model." + for k in list(state_dict.keys()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + + logger.info("Building VAE") + vae = sd3_models.SDVAE() + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + vae.to(device=device, dtype=vae_dtype) + return vae + + +def load_models( + ckpt_path: str, + clip_l_path: str, + clip_g_path: str, + t5xxl_path: str, + vae_path: str, + attn_mode: str, + device: Union[str, torch.device], + weight_dtype: Optional[Union[str, torch.dtype]] = None, + disable_mmap: bool = False, + clip_dtype: Optional[Union[str, torch.dtype]] = None, + t5xxl_device: Optional[Union[str, torch.device]] = None, + t5xxl_dtype: Optional[Union[str, torch.dtype]] = None, + vae_dtype: Optional[Union[str, torch.dtype]] = None, +): + """ + Load SD3 models from checkpoint files. + + Args: + ckpt_path: Path to the SD3 checkpoint file. + clip_l_path: Path to the clip_l checkpoint file. + clip_g_path: Path to the clip_g checkpoint file. + t5xxl_path: Path to the t5xxl checkpoint file. + vae_path: Path to the VAE checkpoint file. + attn_mode: Attention mode for MMDiT model. + device: Device for MMDiT model. + weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different. + disable_mmap: Disable memory mapping when loading state dict. + clip_dtype: Dtype for Clip models, or None to use default dtype. + t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device. + t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype. + vae_dtype: Dtype for VAE model, or None to use default dtype. + + Returns: + Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models. + """ + + # In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict. + # However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict. + # Therefore, we need clip_dtype and t5xxl_dtype. + + def load_state_dict(path: str, dvc: Union[str, torch.device] = device): + if disable_mmap: + return safetensors.torch.load(open(path, "rb").read()) + else: + try: + return load_file(path, device=dvc) + except: + return load_file(path) # prevent device invalid Error + + t5xxl_device = t5xxl_device or device + clip_dtype = clip_dtype or weight_dtype or torch.float32 + t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32 + vae_dtype = vae_dtype or weight_dtype or torch.float32 + + logger.info(f"Loading SD3 models from {ckpt_path}...") + state_dict = load_state_dict(ckpt_path) + + # load clip_l + clip_l_sd = None + if clip_l_path: + logger.info(f"Loading clip_l from {clip_l_path}...") + clip_l_sd = load_state_dict(clip_l_path) + for key in list(clip_l_sd.keys()): + clip_l_sd["transformer." + key] = clip_l_sd.pop(key) + else: + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + + # load clip_g + clip_g_sd = None + if clip_g_path: + logger.info(f"Loading clip_g from {clip_g_path}...") + clip_g_sd = load_state_dict(clip_g_path) + for key in list(clip_g_sd.keys()): + clip_g_sd["transformer." + key] = clip_g_sd.pop(key) + else: + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + + # load t5xxl + t5xxl_sd = None + if t5xxl_path: + logger.info(f"Loading t5xxl from {t5xxl_path}...") + t5xxl_sd = load_state_dict(t5xxl_path, t5xxl_device) + for key in list(t5xxl_sd.keys()): + t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) + else: + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k in list(state_dict.keys()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + + # MMDiT and VAE + vae_sd = {} + if vae_path: + logger.info(f"Loading VAE from {vae_path}...") + vae_sd = load_state_dict(vae_path) + else: + # remove prefix "first_stage_model." + vae_sd = {} + vae_prefix = "first_stage_model." + for k in list(state_dict.keys()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + + mmdit_prefix = "model.diffusion_model." + for k in list(state_dict.keys()): + if k.startswith(mmdit_prefix): + state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) + else: + state_dict.pop(k) # remove other keys + + # load MMDiT + logger.info("Building MMDit") + with init_empty_weights(): + mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) + + logger.info("Loading state dict...") + info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) + logger.info(f"Loaded MMDiT: {info}") + + # load ClipG and ClipL + if clip_l_sd is None: + clip_l = None + else: + logger.info("Building ClipL") + clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) + logger.info("Loading state dict...") + info = clip_l.load_state_dict(clip_l_sd) + logger.info(f"Loaded ClipL: {info}") + clip_l.set_attn_mode(attn_mode) + + if clip_g_sd is None: + clip_g = None + else: + logger.info("Building ClipG") + clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) + logger.info("Loading state dict...") + info = clip_g.load_state_dict(clip_g_sd) + logger.info(f"Loaded ClipG: {info}") + clip_g.set_attn_mode(attn_mode) + + # load T5XXL + if t5xxl_sd is None: + t5xxl = None + else: + logger.info("Building T5XXL") + t5xxl = sd3_models.create_t5xxl(t5xxl_device, t5xxl_dtype, t5xxl_sd) + logger.info("Loading state dict...") + info = t5xxl.load_state_dict(t5xxl_sd) + logger.info(f"Loaded T5XXL: {info}") + t5xxl.set_attn_mode(attn_mode) + + # load VAE + logger.info("Building VAE") + vae = sd3_models.SDVAE() + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + vae.to(device=device, dtype=vae_dtype) + + return mmdit, clip_l, clip_g, t5xxl, vae + + +# endregion +# region utils + + +def get_cond( + prompt: str, + tokenizer: sd3_models.SD3Tokenizer, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): + l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) + return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype) + + +def get_cond_from_tokens( + l_tokens, + g_tokens, + t5_tokens, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): + l_out, l_pooled = clip_l.encode_token_weights(l_tokens) + g_out, g_pooled = clip_g.encode_token_weights(g_tokens) + lg_out = torch.cat([l_out, g_out], dim=-1) + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + if device is not None: + lg_out = lg_out.to(device=device) + l_pooled = l_pooled.to(device=device) + g_pooled = g_pooled.to(device=device) + if dtype is not None: + lg_out = lg_out.to(dtype=dtype) + l_pooled = l_pooled.to(dtype=dtype) + g_pooled = g_pooled.to(dtype=dtype) + + # t5xxl may be in another device (eg. cpu) + if t5_tokens is None: + t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) + else: + t5_out, _ = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None + if device is not None: + t5_out = t5_out.to(device=device) + if dtype is not None: + t5_out = t5_out.to(dtype=dtype) + + # return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) + return lg_out, t5_out, torch.cat((l_pooled, g_pooled), dim=-1) + + +# used if other sd3 models is available +r""" +def get_sd3_configs(state_dict: Dict): + # Important configuration values can be quickly determined by checking shapes in the source file + # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change) + # prefix = "model.diffusion_model." + prefix = "" + + patch_size = state_dict[prefix + "x_embedder.proj.weight"].shape[2] + depth = state_dict[prefix + "x_embedder.proj.weight"].shape[0] // 64 + num_patches = state_dict[prefix + "pos_embed"].shape[1] + pos_embed_max_size = round(math.sqrt(num_patches)) + adm_in_channels = state_dict[prefix + "y_embedder.mlp.0.weight"].shape[1] + context_shape = state_dict[prefix + "context_embedder.weight"].shape + context_embedder_config = { + "target": "torch.nn.Linear", + "params": {"in_features": context_shape[1], "out_features": context_shape[0]}, + } + return { + "patch_size": patch_size, + "depth": depth, + "num_patches": num_patches, + "pos_embed_max_size": pos_embed_max_size, + "adm_in_channels": adm_in_channels, + "context_embedder": context_embedder_config, + } + + +def create_mmdit_from_sd3_checkpoint(state_dict: Dict, attn_mode: str = "xformers"): + "" + Doesn't load state dict. + "" + sd3_configs = get_sd3_configs(state_dict) + + mmdit = sd3_models.MMDiT( + input_size=None, + pos_embed_max_size=sd3_configs["pos_embed_max_size"], + patch_size=sd3_configs["patch_size"], + in_channels=16, + adm_in_channels=sd3_configs["adm_in_channels"], + depth=sd3_configs["depth"], + mlp_ratio=4, + qk_norm=None, + num_patches=sd3_configs["num_patches"], + context_size=4096, + attn_mode=attn_mode, + ) + return mmdit +""" + + +class ModelSamplingDiscreteFlow: + """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models""" + + def __init__(self, shift=1.0): + self.shift = shift + timesteps = 1000 + self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1)) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma * 1000 + + def sigma(self, timestep: torch.Tensor): + timestep = timestep / 1000.0 + if self.shift == 1.0: + return timestep + return self.shift * timestep / (1 + (self.shift - 1) * timestep) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + # assert max_denoise is False, "max_denoise not implemented" + # max_denoise is always True, I'm not sure why it's there + return sigma * noise + (1.0 - sigma) * latent_image + + +# endregion diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index f03f1bae5..4fad78a1c 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -1,4 +1,5 @@ import torch +import safetensors from accelerate import init_empty_weights from accelerate.utils.modeling import set_module_tensor_to_device from safetensors.torch import load_file, save_file @@ -8,8 +9,10 @@ from library import model_util from library import sdxl_original_unet from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) VAE_SCALE_FACTOR = 0.13025 @@ -163,17 +166,20 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None): raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))) -def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False): # model_version is reserved for future use # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching # Load the state dict if model_util.is_safetensors(ckpt_path): checkpoint = None - try: - state_dict = load_file(ckpt_path, device=map_location) - except: - state_dict = load_file(ckpt_path) # prevent device invalid Error + if disable_mmap: + state_dict = safetensors.torch.load(open(ckpt_path, "rb").read()) + else: + try: + state_dict = load_file(ckpt_path, device=map_location) + except: + state_dict = load_file(ckpt_path) # prevent device invalid Error epoch = None global_step = None else: diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index a29013e34..6726ca07c 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -5,6 +5,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate import init_empty_weights @@ -13,8 +14,10 @@ from library import model_util, sdxl_model_util, train_util, sdxl_original_unet from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) TOKENIZER1_PATH = "openai/clip-vit-large-patch14" @@ -44,6 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): weight_dtype, accelerator.device if args.lowram else "cpu", model_dtype, + # args.disable_mmap_load_safetensors, ) # work on low-ram device @@ -60,7 +64,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): def _load_target_model( - name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None + name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False ): # model_dtype only work with full fp16/bf16 name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path @@ -75,7 +79,7 @@ def _load_target_model( unet, logit_scale, ckpt_info, - ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype) + ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap) else: # Diffusers model is loaded to CPU from diffusers import StableDiffusionXLPipeline @@ -332,6 +336,11 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): action="store_true", help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", ) + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): diff --git a/library/train_util.py b/library/train_util.py index 0fec565db..e0c8a0a45 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -58,7 +58,7 @@ KDPM2AncestralDiscreteScheduler, AutoencoderKL, ) -from library import custom_train_functions +from library import custom_train_functions, sd3_utils from library.original_unet import UNet2DConditionModel from huggingface_hub import hf_hub_download import numpy as np @@ -135,6 +135,7 @@ ) TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" class ImageInfo: @@ -159,6 +160,7 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, self.text_encoder_outputs1: Optional[torch.Tensor] = None self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None + self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime class BucketManager: @@ -357,10 +359,52 @@ def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarra return self.color_aug if use_color_aug else None +class LatentsCachingStrategy: + _strategy = None # strategy instance: actual strategy class + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + raise NotImplementedError + + def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str: + raise NotImplementedError + + def is_disk_cached_latents_expected( + self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ) -> bool: + raise NotImplementedError + + def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): + raise NotImplementedError + + class BaseSubset: def __init__( self, image_dir: Optional[str], + alpha_mask: Optional[bool], num_repeats: int, shuffle_caption: bool, caption_separator: str, @@ -381,6 +425,7 @@ def __init__( token_warmup_step: Union[float, int], ) -> None: self.image_dir = image_dir + self.alpha_mask = alpha_mask if alpha_mask is not None else False self.num_repeats = num_repeats self.shuffle_caption = shuffle_caption self.caption_separator = caption_separator @@ -412,6 +457,7 @@ def __init__( class_tokens: Optional[str], caption_extension: str, cache_info: bool, + alpha_mask: bool, num_repeats, shuffle_caption, caption_separator: str, @@ -435,6 +481,7 @@ def __init__( super().__init__( image_dir, + alpha_mask, num_repeats, shuffle_caption, caption_separator, @@ -473,6 +520,7 @@ def __init__( self, image_dir, metadata_file: str, + alpha_mask: bool, num_repeats, shuffle_caption, caption_separator, @@ -496,6 +544,7 @@ def __init__( super().__init__( image_dir, + alpha_mask, num_repeats, shuffle_caption, caption_separator, @@ -554,6 +603,7 @@ def __init__( super().__init__( image_dir, + False, # alpha_mask num_repeats, shuffle_caption, caption_separator, @@ -649,8 +699,16 @@ def set_caching_mode(self, mode): def set_current_epoch(self, epoch): if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする - self.shuffle_buckets() - self.current_epoch = epoch + if epoch > self.current_epoch: + logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + num_epochs = epoch - self.current_epoch + for _ in range(num_epochs): + self.current_epoch += 1 + self.shuffle_buckets() + # self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader? + else: + logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + self.current_epoch = epoch def set_current_step(self, step): self.current_step = step @@ -915,7 +973,7 @@ def make_buckets(self): logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる - self.buckets_indices: List(BucketBatchIndex) = [] + self.buckets_indices: List[BucketBatchIndex] = [] for bucket_index, bucket in enumerate(self.bucket_manager.buckets): batch_count = int(math.ceil(len(bucket) / self.batch_size)) for batch_index in range(batch_count): @@ -969,7 +1027,70 @@ def is_text_encoder_output_cacheable(self): ] ) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy): + r""" + a brand new method to cache latents. This method caches latents with caching strategy. + normal cache_latents method is used by default, but this method is used when caching strategy is specified. + """ + logger.info("caching latents with caching strategy.") + image_infos = list(self.image_data.values()) + + # sort by resolution + image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) + + # split by resolution + batches = [] + batch = [] + logger.info("checking cache validity...") + for info in tqdm(image_infos): + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: # fine tuning dataset + continue + + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) + if not is_main_process: # prepare for multi-gpu, only store to info + continue + + cache_available = caching_strategy.is_disk_cached_latents_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) + if cache_available: # do not add to batch + continue + + # if last member of batch has different resolution, flush the batch + if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: + batches.append(batch) + batch = [] + + batch.append(info) + + # if number of data in batch is enough, flush the batch + if len(batch) >= caching_strategy.batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + # if cache to disk, don't cache latents in non-main process, set to info only + if caching_strategy.cache_to_disk and not is_main_process: + return + + if len(batches) == 0: + logger.info("no latents to cache") + return + + # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded + logger.info("caching latents...") + for batch in tqdm(batches, smoothing=1, total=len(batches)): + # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching latents.") @@ -990,11 +1111,13 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # check disk cache exists and size of latents if cache_to_disk: - info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" + info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix if not is_main_process: # store to info only continue - cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug) + cache_available = is_disk_cached_latents_is_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) if cache_available: # do not add to batch continue @@ -1020,19 +1143,54 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): - cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop) + cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) - # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる - # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する - # SD1/2に対応するにはv2のフラグを持つ必要があるので後回し + # if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype + # this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset + # to support SD1/2, it needs a flag for v2, but it is postponed def cache_text_encoder_outputs( - self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True + self, tokenizers, text_encoders, device, output_dtype, cache_to_disk=False, is_main_process=True ): assert len(tokenizers) == 2, "only support SDXL" + return self.cache_text_encoder_outputs_common( + tokenizers, text_encoders, [device, device], output_dtype, [output_dtype], cache_to_disk, is_main_process + ) + + # same as above, but for SD3 + def cache_text_encoder_outputs_sd3( + self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None + ): + return self.cache_text_encoder_outputs_common( + [tokenizer], + text_encoders, + devices, + output_dtype, + te_dtypes, + cache_to_disk, + is_main_process, + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3, + batch_size, + ) + def cache_text_encoder_outputs_common( + self, + tokenizers, + text_encoders, + devices, + output_dtype, + te_dtypes, + cache_to_disk=False, + is_main_process=True, + file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, + batch_size=None, + ): # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") + + if batch_size is None: + batch_size = self.batch_size + image_infos = list(self.image_data.values()) logger.info("checking cache existence...") @@ -1040,13 +1198,14 @@ def cache_text_encoder_outputs( for info in tqdm(image_infos): # subset = self.image_to_subset[info.image_key] if cache_to_disk: - te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX + te_out_npz = os.path.splitext(info.absolute_path)[0] + file_suffix info.text_encoder_outputs_npz = te_out_npz if not is_main_process: # store to info only continue if os.path.exists(te_out_npz): + # TODO check varidity of cache here continue image_infos_to_cache.append(info) @@ -1055,20 +1214,25 @@ def cache_text_encoder_outputs( return # prepare tokenizers and text encoders - for text_encoder in text_encoders: + for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes): text_encoder.to(device) - if weight_dtype is not None: - text_encoder.to(dtype=weight_dtype) + if te_dtype is not None: + text_encoder.to(dtype=te_dtype) # create batch + is_sd3 = len(tokenizers) == 1 batch = [] batches = [] for info in image_infos_to_cache: - input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) - input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) - batch.append((info, input_ids1, input_ids2)) + if not is_sd3: + input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) + input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) + batch.append((info, input_ids1, input_ids2)) + else: + l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption) + batch.append((info, l_tokens, g_tokens, t5_tokens)) - if len(batch) >= self.batch_size: + if len(batch) >= batch_size: batches.append(batch) batch = [] @@ -1077,19 +1241,38 @@ def cache_text_encoder_outputs( # iterate batches: call text encoder and cache outputs for memory or disk logger.info("caching text encoder outputs...") - for batch in tqdm(batches): - infos, input_ids1, input_ids2 = zip(*batch) - input_ids1 = torch.stack(input_ids1, dim=0) - input_ids2 = torch.stack(input_ids2, dim=0) - cache_batch_text_encoder_outputs( - infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype - ) + if not is_sd3: + for batch in tqdm(batches): + infos, input_ids1, input_ids2 = zip(*batch) + input_ids1 = torch.stack(input_ids1, dim=0) + input_ids2 = torch.stack(input_ids2, dim=0) + cache_batch_text_encoder_outputs( + infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, output_dtype + ) + else: + for batch in tqdm(batches): + infos, l_tokens, g_tokens, t5_tokens = zip(*batch) + + # stack tokens + # l_tokens = [tokens[0] for tokens in l_tokens] + # g_tokens = [tokens[0] for tokens in g_tokens] + # t5_tokens = [tokens[0] for tokens in t5_tokens] + + cache_batch_text_encoder_outputs_sd3( + infos, + tokenizers[0], + text_encoders, + self.max_token_length, + cache_to_disk, + (l_tokens, g_tokens, t5_tokens), + output_dtype, + ) def get_image_size(self, image_path): return imagesize.get(image_path) - def load_image_with_face_info(self, subset: BaseSubset, image_path: str): - img = load_image(image_path) + def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): + img = load_image(image_path, alpha_mask) face_cx = face_cy = face_w = face_h = 0 if subset.face_crop_aug_range is not None: @@ -1166,6 +1349,7 @@ def __getitem__(self, index): input_ids_list = [] input_ids2_list = [] latents_list = [] + alpha_mask_list = [] images = [] original_sizes_hw = [] crop_top_lefts = [] @@ -1190,21 +1374,28 @@ def __getitem__(self, index): crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped if not flipped: latents = image_info.latents + alpha_mask = image_info.alpha_mask else: latents = image_info.latents_flipped + alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1]) image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz) if flipped: latents = flipped_latents + alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem del flipped_latents latents = torch.FloatTensor(latents) + if alpha_mask is not None: + alpha_mask = torch.FloatTensor(alpha_mask) image = None else: # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path) + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info( + subset, image_info.absolute_path, subset.alpha_mask + ) im_h, im_w = img.shape[0:2] if self.enable_bucket: @@ -1236,16 +1427,33 @@ def __getitem__(self, index): # augmentation aug = self.aug_helper.get_augmentor(subset.color_aug) if aug is not None: - img = aug(image=img)["image"] + # augment RGB channels only + img_rgb = img[:, :, :3] + img_rgb = aug(image=img_rgb)["image"] + img[:, :, :3] = img_rgb if flipped: img = img[:, ::-1, :].copy() # copy to avoid negative stride problem + if subset.alpha_mask: + if img.shape[2] == 4: + alpha_mask = img[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0 + alpha_mask = torch.FloatTensor(alpha_mask) + else: + alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) + else: + alpha_mask = None + + img = img[:, :, :3] # remove alpha channel + latents = None image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + del img images.append(image) latents_list.append(latents) + alpha_mask_list.append(alpha_mask) target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) @@ -1289,6 +1497,7 @@ def __getitem__(self, index): captions.append(caption) if not self.token_padding_disabled: # this option might be omitted in future + # TODO get_input_ids must support SD3 if self.XTI_layers: token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) else: @@ -1331,6 +1540,23 @@ def __getitem__(self, index): example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list) example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list) + # if one of alpha_masks is not None, we need to replace None with ones + none_or_not = [x is None for x in alpha_mask_list] + if all(none_or_not): + example["alpha_masks"] = None + elif any(none_or_not): + for i in range(len(alpha_mask_list)): + if alpha_mask_list[i] is None: + if images[i] is not None: + alpha_mask_list[i] = torch.ones((images[i].shape[1], images[i].shape[2]), dtype=torch.float32) + else: + alpha_mask_list[i] = torch.ones( + (latents_list[i].shape[1] * 8, latents_list[i].shape[2] * 8), dtype=torch.float32 + ) + example["alpha_masks"] = torch.stack(alpha_mask_list) + else: + example["alpha_masks"] = torch.stack(alpha_mask_list) + if images[0] is not None: images = torch.stack(images) images = images.to(memory_format=torch.contiguous_format).float() @@ -1361,6 +1587,7 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): resized_sizes = [] bucket_reso = None flip_aug = None + alpha_mask = None random_crop = None for image_key in bucket[image_index : image_index + bucket_batch_size]: @@ -1369,10 +1596,13 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): if flip_aug is None: flip_aug = subset.flip_aug + alpha_mask = subset.alpha_mask random_crop = subset.random_crop bucket_reso = image_info.bucket_reso else: + # TODO そもそも混在してても動くようにしたほうがいい assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch" + assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch" assert random_crop == subset.random_crop, "random_crop must be same in a batch" assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch" @@ -1409,6 +1639,7 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): example["absolute_paths"] = absolute_paths example["resized_sizes"] = resized_sizes example["flip_aug"] = flip_aug + example["alpha_mask"] = alpha_mask example["random_crop"] = random_crop example["bucket_reso"] = bucket_reso return example @@ -1489,7 +1720,7 @@ def read_caption(img_path, caption_extension, enable_wildcard): def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): logger.warning(f"not directory: {subset.image_dir}") - return [], [], [] + return [], [] info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE) use_cached_info_for_subset = subset.cache_info @@ -1516,6 +1747,18 @@ def load_dreambooth_dir(subset: DreamBoothSubset): img_paths = glob_images(subset.image_dir, "*") sizes = [None] * len(img_paths) + # new caching: get image size from cache files + strategy = LatentsCachingStrategy.get_strategy() + if strategy is not None: + logger.info("get image size from cache files") + size_set_count = 0 + for i, img_path in enumerate(tqdm(img_paths)): + w, h = strategy.get_image_size_from_image_absolute_path(img_path) + if w is not None and h is not None: + sizes[i] = [w, h] + size_set_count += 1 + logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") if use_cached_info_for_subset: @@ -1892,6 +2135,7 @@ def __init__( None, subset.caption_extension, subset.cache_info, + False, subset.num_repeats, subset.shuffle_caption, subset.caption_separator, @@ -2074,10 +2318,15 @@ def enable_XTI(self, *args, **kwargs): for dataset in self.datasets: dataset.enable_XTI(*args, **kwargs) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) + + def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.new_cache_latents(is_main_process, strategy) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2086,6 +2335,15 @@ def cache_text_encoder_outputs( logger.info(f"[Dataset {i}]") dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) + def cache_text_encoder_outputs_sd3( + self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None + ): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.cache_text_encoder_outputs_sd3( + tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size + ) + def set_caching_mode(self, caching_mode): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) @@ -2117,31 +2375,44 @@ def disable_token_padding(self): dataset.disable_token_padding() -def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): +def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool): expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 if not os.path.exists(npz_path): return False - npz = np.load(npz_path) - if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? - return False - if npz["latents"].shape[1:3] != expected_latents_size: - return False - - if flip_aug: - if "latents_flipped" not in npz: + try: + npz = np.load(npz_path) + if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: + if npz["latents"].shape[1:3] != expected_latents_size: return False + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if npz["alpha_mask"].shape[0:2] != reso: # HxW + return False + else: + if "alpha_mask" in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + return True # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]: +) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old format. please re-generate {npz_path}") @@ -2150,13 +2421,16 @@ def load_latents_from_disk( original_size = npz["original_size"].tolist() crop_ltrb = npz["crop_ltrb"].tolist() flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - return latents, original_size, crop_ltrb, flipped_latents + alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None): +def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): kwargs = {} if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() np.savez( npz_path, latents=latents_tensor.float().cpu().numpy(), @@ -2228,6 +2502,13 @@ def debug_dataset(train_dataset, show_input_ids=False): if os.name == "nt": cv2.imshow("cond_img", cond_img) + if "alpha_masks" in example and example["alpha_masks"] is not None: + alpha_mask = example["alpha_masks"][j] + logger.info(f"alpha mask size: {alpha_mask.size()}") + alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8) + if os.name == "nt": + cv2.imshow("alpha_mask", alpha_mask) + if os.name == "nt": # only windows cv2.imshow("img", im) k = cv2.waitKey() @@ -2345,17 +2626,21 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group -def load_image(image_path): +def load_image(image_path, alpha=False): image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") + if alpha: + if not image.mode == "RGBA": + image = image.convert("RGBA") + else: + if not image.mode == "RGB": + image = image.convert("RGB") img = np.array(image, np.uint8) return img # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) def trim_and_resize_if_required( - random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int] + random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] original_size = (image_width, image_height) # size before resize @@ -2386,8 +2671,53 @@ def trim_and_resize_if_required( return image, original_size, crop_ltrb +# for new_cache_latents +def load_images_and_masks_for_caching( + image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool +) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: + r""" + requires image_infos to have: [absolute_path or image], bucket_reso, resized_size + + returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs + + image_tensor: torch.Tensor = torch.Size([B, 3, H, W]), ...], normalized to [-1, 1] + alpha_masks: List[np.ndarray] = [np.ndarray([H, W]), ...], normalized to [0, 1] + original_sizes: List[Tuple[int, int]] = [(W, H), ...] + crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...] + """ + images: List[torch.Tensor] = [] + alpha_masks: List[np.ndarray] = [] + original_sizes: List[Tuple[int, int]] = [] + crop_ltrbs: List[Tuple[int, int, int, int]] = [] + for info in image_infos: + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + + original_sizes.append(original_size) + crop_ltrbs.append(crop_ltrb) + + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] + else: + alpha_mask = None + alpha_masks.append(alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) + + img_tensor = torch.stack(images, dim=0) + return img_tensor, alpha_masks, original_sizes, crop_ltrbs + + def cache_batch_latents( - vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool + vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool ) -> None: r""" requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz @@ -2399,16 +2729,30 @@ def cache_batch_latents( latents_original_size and latents_crop_ltrb are also set """ images = [] + alpha_masks: List[np.ndarray] = [] for info in image_infos: - image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8) + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) - image = IMAGE_TRANSFORMS(image) - images.append(image) info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] + else: + alpha_mask = None + alpha_masks.append(alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) + img_tensors = torch.stack(images, dim=0) img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) @@ -2422,17 +2766,25 @@ def cache_batch_latents( else: flipped_latents = [None] * len(latents) - for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents): + for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): # check NaN if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent) + save_latents_to_disk( + info.latents_npz, + latent, + info.latents_original_size, + info.latents_crop_ltrb, + flipped_latent, + alpha_mask, + ) else: info.latents = latent if flip_aug: info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask if not HIGH_VRAM: clean_memory_on_device(vae.device) @@ -2470,6 +2822,34 @@ def cache_batch_text_encoder_outputs( info.text_encoder_pool2 = pool2 +def cache_batch_text_encoder_outputs_sd3( + image_infos, tokenizer, text_encoders, max_token_length, cache_to_disk, input_ids, output_dtype +): + # make input_ids for each text encoder + l_tokens, g_tokens, t5_tokens = input_ids + + clip_l, clip_g, t5xxl = text_encoders + with torch.no_grad(): + b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens( + l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, "cpu", output_dtype + ) + b_lg_out = b_lg_out.detach() + b_t5_out = b_t5_out.detach() + b_pool = b_pool.detach() + + for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): + # debug: NaN check + if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any(): + raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}") + + if cache_to_disk: + save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) + else: + info.text_encoder_outputs1 = lg_out + info.text_encoder_outputs2 = t5_out + info.text_encoder_pool2 = pool + + def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2): np.savez( npz_path, @@ -2792,6 +3172,7 @@ def get_sai_model_spec( lora: bool, textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA + sd3: str = None, ): timestamp = time.time() @@ -2825,6 +3206,7 @@ def get_sai_model_spec( tags=args.metadata_tags, timesteps=timesteps, clip_skip=args.clip_skip, # None or int + sd3=sd3, ) return metadata @@ -2920,6 +3302,12 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) + parser.add_argument( + "--fused_backward_pass", + action="store_true", + help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL" + + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", + ) def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): @@ -3087,7 +3475,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument( - "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする" + "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする" ) parser.add_argument( "--gradient_accumulation_steps", @@ -3170,6 +3558,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)", ) + parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する") parser.add_argument( "--noise_offset", @@ -3379,6 +3768,42 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser): ) +def get_sanitized_config_or_none(args: argparse.Namespace): + # if `--log_config` is enabled, return args for logging. if not, return None. + # when `--log_config is enabled, filter out sensitive values from args + # if wandb is not enabled, the log is not exposed to the public, but it is fine to filter out sensitive values to be safe + + if not args.log_config: + return None + + sensitive_args = ["wandb_api_key", "huggingface_token"] + sensitive_path_args = [ + "pretrained_model_name_or_path", + "vae", + "tokenizer_cache_dir", + "train_data_dir", + "conditioning_data_dir", + "reg_data_dir", + "output_dir", + "logging_dir", + ] + filtered_args = {} + for k, v in vars(args).items(): + # filter out sensitive values and convert to string if necessary + if k not in sensitive_args + sensitive_path_args: + # Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`. + if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int): + filtered_args[k] = v + # accelerate does not support lists + elif isinstance(v, list): + filtered_args[k] = f"{v}" + # accelerate does not support objects + elif isinstance(v, object): + filtered_args[k] = f"{v}" + + return filtered_args + + # verify command line args for training def verify_command_line_training_args(args: argparse.Namespace): # if wandb is enabled, the command line is exposed to the public @@ -3636,6 +4061,11 @@ def add_dataset_arguments( default=0, help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", ) + parser.add_argument( + "--alpha_mask", + action="store_true", + help="use alpha channel as mask for training / 画像のアルファチャンネルをlossのマスクに使用する", + ) parser.add_argument( "--dataset_class", @@ -3846,6 +4276,14 @@ def get_optimizer(args, trainable_params): optimizer_type = "AdamW" optimizer_type = optimizer_type.lower() + if args.fused_backward_pass: + assert ( + optimizer_type == "Adafactor".lower() + ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します" + assert ( + args.gradient_accumulation_steps == 1 + ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません" + # 引数を分解する optimizer_kwargs = {} if args.optimizer_args is not None and len(args.optimizer_args) > 0: @@ -5302,8 +5740,8 @@ def sample_image_inference( controlnet_image=controlnet_image, ) - with torch.cuda.device(torch.cuda.current_device()): - torch.cuda.empty_cache() + #with torch.cuda.device(torch.cuda.current_device()): + # torch.cuda.empty_cache() image = pipeline.latents_to_image(latents)[0] @@ -5390,6 +5828,8 @@ def add(self, *, epoch: int, step: int, loss: float) -> None: if epoch == 0: self.loss_list.append(loss) else: + while len(self.loss_list) <= step: + self.loss_list.append(0.0) self.loss_total -= self.loss_list[step] self.loss_list[step] = loss self.loss_total += loss diff --git a/requirements.txt b/requirements.txt index 977c5cd91..e99775b8a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ transformers==4.36.2 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 -opencv-python==4.8.1.78 +opencv-python==4.7.0.68 einops==0.7.0 pytorch-lightning==1.9.0 bitsandbytes==0.43.0 diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py new file mode 100644 index 000000000..ffa0d46de --- /dev/null +++ b/sd3_minimal_inference.py @@ -0,0 +1,351 @@ +# Minimum Inference Code for SD3 + +import argparse +import datetime +import math +import os +import random +from typing import Optional, Tuple +import numpy as np + +import torch +from safetensors.torch import safe_open, load_file +from tqdm import tqdm +from PIL import Image + +from library.device_utils import init_ipex, get_preferred_device + +init_ipex() + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sd3_models, sd3_utils + + +def get_noise(seed, latent): + generator = torch.manual_seed(seed) + return torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu").to(latent.dtype) + + +def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): + start = sampling.timestep(sampling.sigma_max) + end = sampling.timestep(sampling.sigma_min) + timesteps = torch.linspace(start, end, steps) + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(sampling.sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs) + + +def max_denoise(model_sampling, sigmas): + max_sigma = float(model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma + + +def do_sample( + height: int, + width: int, + initial_latent: Optional[torch.Tensor], + seed: int, + cond: Tuple[torch.Tensor, torch.Tensor], + neg_cond: Tuple[torch.Tensor, torch.Tensor], + mmdit: sd3_models.MMDiT, + steps: int, + guidance_scale: float, + dtype: torch.dtype, + device: str, +): + if initial_latent is None: + # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 # this seems to be a bug in the original code. thanks to furusu for pointing it out + latent = torch.zeros(1, 16, height // 8, width // 8, device=device) + else: + latent = initial_latent + + latent = latent.to(dtype).to(device) + + noise = get_noise(seed, latent).to(device) + + model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 + + sigmas = get_sigmas(model_sampling, steps).to(device) + # sigmas = sigmas[int(steps * (1 - denoise)) :] # do not support i2i + + # conditioning = fix_cond(conditioning) + # neg_cond = fix_cond(neg_cond) + # extra_args = {"cond": cond, "uncond": neg_cond, "cond_scale": guidance_scale} + + noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) + + c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype) + y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype) + + x = noise_scaled.to(device).to(dtype) + # print(x.shape) + + with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] + + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) + + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * guidance_scale + # print(denoised.shape) + + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims + + dt = sigmas[i + 1] - sigma_hat + + # Euler method + x = x + d * dt + x = x.to(dtype) + + latent = x + scale_factor = 1.5305 + shift_factor = 0.0609 + # def process_out(self, latent): + # return (latent / self.scale_factor) + self.shift_factor + latent = (latent / scale_factor) + shift_factor + return latent + + +if __name__ == "__main__": + target_height = 1024 + target_width = 1024 + + # steps = 50 # 28 # 50 + guidance_scale = 5 + # seed = 1 # None # 1 + + device = get_preferred_device() + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--clip_g", type=str, required=False) + parser.add_argument("--clip_l", type=str, required=False) + parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--prompt", type=str, default="A photo of a cat") + # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders + parser.add_argument("--negative_prompt", type=str, default="") + parser.add_argument("--output_dir", type=str, default=".") + parser.add_argument("--do_not_use_t5xxl", action="store_true") + parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch") + parser.add_argument("--fp16", action="store_true") + parser.add_argument("--bf16", action="store_true") + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--steps", type=int, default=50) + # parser.add_argument( + # "--lora_weights", + # type=str, + # nargs="*", + # default=[], + # help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)", + # ) + # parser.add_argument("--interactive", action="store_true") + args = parser.parse_args() + + seed = args.seed + steps = args.steps + + sd3_dtype = torch.float32 + if args.fp16: + sd3_dtype = torch.float16 + elif args.bf16: + sd3_dtype = torch.bfloat16 + + # TODO test with separated safetenors files for each model + + # load state dict + logger.info(f"Loading SD3 models from {args.ckpt_path}...") + state_dict = load_file(args.ckpt_path) + + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k, v in list(state_dict.items()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + else: + logger.info(f"Lodaing clip_g from {args.clip_g}...") + clip_g_sd = load_file(args.clip_g) + for key in list(clip_g_sd.keys()): + clip_g_sd["transformer." + key] = clip_g_sd.pop(key) + + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k, v in list(state_dict.items()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + else: + logger.info(f"Lodaing clip_l from {args.clip_l}...") + clip_l_sd = load_file(args.clip_l) + for key in list(clip_l_sd.keys()): + clip_l_sd["transformer." + key] = clip_l_sd.pop(key) + + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + if not args.do_not_use_t5xxl: + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k, v in list(state_dict.items()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + else: + logger.info("but not used") + for key in list(state_dict.keys()): + if key.startswith("text_encoders.t5xxl."): + state_dict.pop(key) + t5xxl_sd = None + elif args.t5xxl: + assert not args.do_not_use_t5xxl, "t5xxl is not used but specified" + logger.info(f"Lodaing t5xxl from {args.t5xxl}...") + t5xxl_sd = load_file(args.t5xxl) + for key in list(t5xxl_sd.keys()): + t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) + else: + logger.info("t5xxl is not used") + t5xxl_sd = None + + use_t5xxl = t5xxl_sd is not None + + # MMDiT and VAE + vae_sd = {} + vae_prefix = "first_stage_model." + mmdit_prefix = "model.diffusion_model." + for k, v in list(state_dict.items()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + elif k.startswith(mmdit_prefix): + state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) + + # load tokenizers + logger.info("Loading tokenizers...") + tokenizer = sd3_models.SD3Tokenizer(use_t5xxl) # combined tokenizer + + # load models + # logger.info("Create MMDiT from SD3 checkpoint...") + # mmdit = sd3_utils.create_mmdit_from_sd3_checkpoint(state_dict) + logger.info("Create MMDiT") + mmdit = sd3_models.create_mmdit_sd3_medium_configs(args.attn_mode) + + logger.info("Loading state dict...") + info = mmdit.load_state_dict(state_dict) + logger.info(f"Loaded MMDiT: {info}") + + logger.info(f"Move MMDiT to {device} and {sd3_dtype}...") + mmdit.to(device, dtype=sd3_dtype) + mmdit.eval() + + # load VAE + logger.info("Create VAE") + vae = sd3_models.SDVAE() + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + + logger.info(f"Move VAE to {device} and {sd3_dtype}...") + vae.to(device, dtype=sd3_dtype) + vae.eval() + + # load text encoders + logger.info("Create clip_l") + clip_l = sd3_models.create_clip_l(device, sd3_dtype, clip_l_sd) + + logger.info("Loading state dict...") + info = clip_l.load_state_dict(clip_l_sd) + logger.info(f"Loaded clip_l: {info}") + + logger.info(f"Move clip_l to {device} and {sd3_dtype}...") + clip_l.to(device, dtype=sd3_dtype) + clip_l.eval() + logger.info(f"Set attn_mode to {args.attn_mode}...") + clip_l.set_attn_mode(args.attn_mode) + + logger.info("Create clip_g") + clip_g = sd3_models.create_clip_g(device, sd3_dtype, clip_g_sd) + + logger.info("Loading state dict...") + info = clip_g.load_state_dict(clip_g_sd) + logger.info(f"Loaded clip_g: {info}") + + logger.info(f"Move clip_g to {device} and {sd3_dtype}...") + clip_g.to(device, dtype=sd3_dtype) + clip_g.eval() + logger.info(f"Set attn_mode to {args.attn_mode}...") + clip_g.set_attn_mode(args.attn_mode) + + if use_t5xxl: + logger.info("Create t5xxl") + t5xxl = sd3_models.create_t5xxl(device, sd3_dtype, t5xxl_sd) + + logger.info("Loading state dict...") + info = t5xxl.load_state_dict(t5xxl_sd) + logger.info(f"Loaded t5xxl: {info}") + + logger.info(f"Move t5xxl to {device} and {sd3_dtype}...") + t5xxl.to(device, dtype=sd3_dtype) + # t5xxl.to("cpu", dtype=torch.float32) # run on CPU + t5xxl.eval() + logger.info(f"Set attn_mode to {args.attn_mode}...") + t5xxl.set_attn_mode(args.attn_mode) + else: + t5xxl = None + + # prepare embeddings + logger.info("Encoding prompts...") + # embeds, pooled_embed + lg_out, t5_out, pooled = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl) + cond = torch.cat([lg_out, t5_out], dim=-2), pooled + + lg_out, t5_out, pooled = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl) + neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled + + # generate image + logger.info("Generating image...") + latent_sampled = do_sample( + target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, guidance_scale, sd3_dtype, device + ) + + # latent to image + with torch.no_grad(): + image = vae.decode(latent_sampled) + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) + decoded_np = decoded_np.astype(np.uint8) + out_image = Image.fromarray(decoded_np) + + # save image + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + out_image.save(output_path) + + logger.info(f"Saved image to {output_path}") diff --git a/sd3_train.py b/sd3_train.py new file mode 100644 index 000000000..e2f622e47 --- /dev/null +++ b/sd3_train.py @@ -0,0 +1,981 @@ +# training with captions + +import argparse +import copy +import math +import os +from multiprocessing import Value +from typing import List +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from diffusers import DDPMScheduler +from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils +from library.sdxl_train_util import match_mixed_precision + +# , sdxl_model_util + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions + +# from library.custom_train_functions import ( +# apply_snr_weight, +# prepare_scheduler_for_custom_training, +# scale_v_prediction_loss_like_noise_prediction, +# add_v_prediction_like_loss, +# apply_debiased_estimation, +# apply_masked_loss, +# ) + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + assert ( + not args.weighted_captions + ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + # assert ( + # not args.train_text_encoder or not args.cache_text_encoder_outputs + # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + + # # training text encoder is not supported + # assert ( + # not args.train_text_encoder + # ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" + + # training without text encoder cache is not supported + assert ( + args.cache_text_encoder_outputs + ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" + + # if args.block_lr: + # block_lrs = [float(lr) for lr in args.block_lr.split(",")] + # assert ( + # len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR + # ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください" + # else: + # block_lrs = None + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # load tokenizer + sd3_tokenizer = sd3_models.SD3Tokenizer() + + # prepare caching strategy + if args.new_caching: + latents_caching_strategy = sd3_train_utils.Sd3LatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + ) + else: + latents_caching_strategy = None + train_util.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[sd3_tokenizer]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, [sd3_tokenizer]) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(8) # TODO これでいいか確認 + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = weight_dtype # torch.float32 if args.no_half_vae else weight_dtype # SD3 VAE works with fp16 + + t5xxl_dtype = weight_dtype + if args.t5xxl_dtype is not None: + if args.t5xxl_dtype == "fp16": + t5xxl_dtype = torch.float16 + elif args.t5xxl_dtype == "bf16": + t5xxl_dtype = torch.bfloat16 + elif args.t5xxl_dtype == "fp32" or args.t5xxl_dtype == "float": + t5xxl_dtype = torch.float32 + else: + raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") + t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device + + clip_dtype = weight_dtype # if not args.train_text_encoder else None + + # モデルを読み込む + attn_mode = "xformers" if args.xformers else "torch" + + assert ( + attn_mode == "torch" + ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + + # SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying. + logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}") + device_to_load = accelerator.device if args.lowram else "cpu" + sd3_state_dict = sd3_utils.load_safetensors( + args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors + ) + + # load VAE for caching latents + vae: sd3_models.SDVAE = None + if cache_latents: + vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + + if not args.new_caching: + vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible + with torch.no_grad(): + train_dataset_group.cache_latents( + vae_wrapper, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + file_suffix="_sd3.npz", + ) + else: + latents_caching_strategy.set_vae(vae) + train_dataset_group.new_cache_latents(accelerator.is_main_process, latents_caching_strategy) + vae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # load clip_l, clip_g, t5xxl for caching text encoder outputs + # # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. + # mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( + # args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype + # ) + clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + assert clip_l is not None, "clip_l is required / clip_lは必須です" + assert clip_g is not None, "clip_g is required / clip_gは必須です" + + t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load) + # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + + # 学習を準備する:モデルを適切な状態にする + train_clip_l = False + train_clip_g = False + train_t5xxl = False + + # if args.train_text_encoder: + # # TODO each option for two text encoders? + # accelerator.print("enable text encoder training") + # if args.gradient_checkpointing: + # text_encoder1.gradient_checkpointing_enable() + # text_encoder2.gradient_checkpointing_enable() + # lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train + # lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train + # train_clip_l = lr_te1 != 0 + # train_clip_g = lr_te2 != 0 + + # # caching one text encoder output is not supported + # if not train_clip_l: + # text_encoder1.to(weight_dtype) + # if not train_clip_g: + # text_encoder2.to(weight_dtype) + # text_encoder1.requires_grad_(train_clip_l) + # text_encoder2.requires_grad_(train_clip_g) + # text_encoder1.train(train_clip_l) + # text_encoder2.train(train_clip_g) + # else: + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + clip_l.requires_grad_(False) + clip_g.requires_grad_(False) + clip_l.eval() + clip_g.eval() + if t5xxl is not None: + t5xxl.to(t5xxl_dtype) + t5xxl.requires_grad_(False) + t5xxl.eval() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + + with torch.no_grad(), accelerator.autocast(): + train_dataset_group.cache_text_encoder_outputs_sd3( + sd3_tokenizer, + (clip_l, clip_g, t5xxl), + (accelerator.device, accelerator.device, t5xxl_device), + None, + (None, None, None), + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + args.text_encoder_batch_size, + ) + + # TODO we can delete text encoders after caching + accelerator.wait_for_everyone() + + # load MMDIT + # if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32). + # by loading with model_dtype, we can reduce memory usage. + model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) + mmdit = sd3_train_utils.load_target_model("mmdit", args, sd3_state_dict, accelerator, attn_mode, model_dtype, device_to_load) + if args.gradient_checkpointing: + mmdit.enable_gradient_checkpointing() + + train_mmdit = args.learning_rate != 0 + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdie will not be prepared + + if not cache_latents: + # load VAE here if not cached + vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + training_models = [] + params_to_optimize = [] + # if train_unet: + training_models.append(mmdit) + # if block_lrs is None: + params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) + # else: + # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) + + # if train_clip_l: + # training_models.append(text_encoder1) + # params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + # if train_clip_g: + # training_models.append(text_encoder2) + # params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"train mmdit: {train_mmdit}") # , text_encoder1: {train_clip_l}, text_encoder2: {train_clip_g}") + accelerator.print(f"number of models: {len(training_models)}") + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.fused_optimizer_groups: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # This balances memory usage and management complexity. + + # calculate total number of parameters + n_total_params = sum(len(params["params"]) for params in params_to_optimize) + params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) + + # split params into groups, keeping the learning rate the same for all params in a group + # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + grouped_params = [] + param_group = [] + param_group_lr = -1 + for group in params_to_optimize: + lr = group["lr"] + for p in group["params"]: + # if the learning rate is different for different params, start a new group + if lr != param_group_lr: + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = lr + + param_group.append(p) + + # if the group has enough parameters, start a new group + if len(param_group) == params_per_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = -1 + + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + # dataloaderを準備する + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + mmdit.to(weight_dtype) + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + if t5xxl is not None: + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + mmdit.to(weight_dtype) + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + if t5xxl is not None: + t5xxl.to(weight_dtype) + + # TODO check if this is necessary. SD3 uses pool for clip_l and clip_g + # # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer + # if train_clip_l: + # text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) + # text_encoder1.text_model.final_layer_norm.requires_grad_(False) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model( + args, + mmdit=mmdit, + # mmdie=mmdit if train_mmdit else None, + # text_encoder1=text_encoder1 if train_clip_l else None, + # text_encoder2=text_encoder2 if train_clip_g else None, + ) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # acceleratorがなんかよろしくやってくれるらしい + if train_mmdit: + mmdit = accelerator.prepare(mmdit) + # if train_clip_l: + # text_encoder1 = accelerator.prepare(text_encoder1) + # if train_clip_g: + # text_encoder2 = accelerator.prepare(text_encoder2) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + clip_l.to("cpu", dtype=torch.float32) + clip_g.to("cpu", dtype=torch.float32) + if t5xxl is not None: + t5xxl.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + # TODO support CPU for text encoders + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(accelerator.device) + + # TODO cache sample prompt's embeddings to free text encoder's memory + if args.cache_text_encoder_outputs: + if not args.save_t5xxl: + t5xxl = None # free memory + clean_memory_on_device(accelerator.device) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + elif args.fused_optimizer_groups: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def optimizer_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(optimizer_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + # noise_scheduler = DDPMScheduler( + # beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + # ) + + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + # prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + # if args.zero_terminal_snr: + # custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + # # For --sample_at_first + # sd3_train_utils.sample_images( + # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], mmdit + # ) + + # following function will be moved to sd3_train_utils + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None + ): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + loss_recorder = train_util.LossRecorder() + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.fused_optimizer_groups: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = vae.encode(batch["images"].to(vae_dtype)).to(weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + # latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + latents = sd3_models.SDVAE.process_in(latents) + + if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: + # not cached, get text encoder outputs + # XXX This does not work yet + input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl = batch["input_ids"] + with torch.set_grad_enabled(args.train_text_encoder): + # TODO support weighted captions + # TODO support length > 75 + input_ids_clip_l = input_ids_clip_l.to(accelerator.device) + input_ids_clip_g = input_ids_clip_g.to(accelerator.device) + input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) + + # get text encoder outputs: outputs are concatenated + context, pool = sd3_utils.get_cond_from_tokens( + input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl, clip_l, clip_g, t5xxl + ) + else: + # encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) + # encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) + # pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + # TODO this reuses SDXL keys, it should be fixed + lg_out = batch["text_encoder_outputs1_list"] + t5_out = batch["text_encoder_outputs2_list"] + pool = batch["text_encoder_pool2_list"] + context = torch.cat([lg_out, t5_out], dim=-2) + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) + + # Add noise according to flow matching. + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + # debug: NaN check for all inputs + if torch.any(torch.isnan(noisy_model_input)): + accelerator.print("NaN found in noisy_model_input, replacing with zeros") + noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input) + if torch.any(torch.isnan(context)): + accelerator.print("NaN found in context, replacing with zeros") + context = torch.nan_to_num(context, 0, out=context) + if torch.any(torch.isnan(pool)): + accelerator.print("NaN found in pool, replacing with zeros") + pool = torch.nan_to_num(pool, 0, out=pool) + + # call model + with accelerator.autocast(): + model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool) + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = latents + + # Compute regular loss. TODO simplify this + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.fused_optimizer_groups): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # sdxl_train_util.sample_images( + # accelerator, + # args, + # None, + # global_step, + # accelerator.device, + # vae, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # mmdit, + # ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(clip_l) if args.save_clip else None, + accelerator.unwrap_model(clip_g) if args.save_clip else None, + accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, + accelerator.unwrap_model(mmdit), + vae, + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_mmdit) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(clip_l) if args.save_clip else None, + accelerator.unwrap_model(clip_g) if args.save_clip else None, + accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, + accelerator.unwrap_model(mmdit), + vae, + ) + + # sdxl_train_util.sample_images( + # accelerator, + # args, + # epoch + 1, + # global_step, + # accelerator.device, + # vae, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # mmdit, + # ) + + is_main_process = accelerator.is_main_process + # if is_main_process: + mmdit = accelerator.unwrap_model(mmdit) + clip_l = accelerator.unwrap_model(clip_l) + clip_g = accelerator.unwrap_model(clip_g) + if t5xxl is not None: + t5xxl = accelerator.unwrap_model(t5xxl) + + accelerator.end_training() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + sd3_train_utils.save_sd3_model_on_train_end( + args, + save_dtype, + epoch, + global_step, + clip_l if args.save_clip else None, + clip_g if args.save_clip else None, + t5xxl if args.save_t5xxl else None, + mmdit, + vae, + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sd3_train_utils.add_sd3_training_arguments(parser) + + # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + + # TE training is disabled temporarily + # parser.add_argument( + # "--learning_rate_te1", + # type=float, + # default=None, + # help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率", + # ) + # parser.add_argument( + # "--learning_rate_te2", + # type=float, + # default=None, + # help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", + # ) + + # parser.add_argument( + # "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + # ) + # parser.add_argument( + # "--no_half_vae", + # action="store_true", + # help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + # ) + # parser.add_argument( + # "--block_lr", + # type=str, + # default=None, + # help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + # + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", + # ) + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + ) + + parser.add_argument("--new_caching", action="store_true", help="use new caching method / 新しいキャッシング方法を使う") + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="skip latents validity check / latentsの正当性チェックをスキップする", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860be..ae92d6a3d 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -272,7 +272,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # 学習を準備する:モデルを適切な状態にする if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - train_unet = args.learning_rate > 0 + train_unet = args.learning_rate != 0 train_text_encoder1 = False train_text_encoder2 = False @@ -284,8 +284,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder2.gradient_checkpointing_enable() lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train - train_text_encoder1 = lr_te1 > 0 - train_text_encoder2 = lr_te2 > 0 + train_text_encoder1 = lr_te1 != 0 + train_text_encoder2 = lr_te2 != 0 # caching one text encoder output is not supported if not train_text_encoder1: @@ -345,8 +345,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # calculate number of trainable parameters n_params = 0 - for params in params_to_optimize: - for p in params["params"]: + for group in params_to_optimize: + for p in group["params"]: n_params += p.numel() accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}") @@ -355,7 +355,53 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + if args.fused_optimizer_groups: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # This balances memory usage and management complexity. + + # calculate total number of parameters + n_total_params = sum(len(params["params"]) for params in params_to_optimize) + params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) + + # split params into groups, keeping the learning rate the same for all params in a group + # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + grouped_params = [] + param_group = [] + param_group_lr = -1 + for group in params_to_optimize: + lr = group["lr"] + for p in group["params"]: + # if the learning rate is different for different params, start a new group + if lr != param_group_lr: + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = lr + + param_group.append(p) + + # if the group has enough parameters, start a new group + if len(param_group) == params_per_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = -1 + + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 @@ -382,7 +428,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする if args.full_fp16: @@ -450,6 +501,57 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + elif args.fused_optimizer_groups: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def optimizer_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(optimizer_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) @@ -487,7 +589,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) # For --sample_at_first sdxl_train_util.sample_images( @@ -504,6 +610,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): for step, batch in enumerate(train_dataloader): current_step.value = global_step + + if args.fused_optimizer_groups: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + with accelerator.accumulate(*training_models): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) @@ -582,7 +692,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -600,8 +712,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): or args.masked_loss ): # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -616,18 +730,28 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + ) accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + if not (args.fused_backward_pass or args.fused_optimizer_groups): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: @@ -736,7 +860,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.end_training() - if args.save_state or args.save_state_on_train_end: + if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -805,6 +929,12 @@ def setup_parser() -> argparse.ArgumentParser: help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", ) + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + ) return parser diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index f89c3628f..5ff060a9f 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -15,6 +15,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -288,6 +289,9 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + if isinstance(unet, DDP): + unet._set_static_graph() # avoid error for multiple use of the parameter + if args.gradient_checkpointing: unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる else: @@ -353,7 +357,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() @@ -439,7 +443,9 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -458,7 +464,9 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -477,7 +485,7 @@ def remove_model(old_ckpt_name): accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = unet.get_trainable_params() + params_to_clip = accelerator.unwrap_model(unet).get_trainable_params() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index e85e978c1..292a0463a 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -324,7 +324,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/train_controlnet.py b/train_controlnet.py index f4c94e8d9..c9ac6c5a8 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -5,7 +5,8 @@ import random import time from multiprocessing import Value -from types import SimpleNamespace + +# from omegaconf import OmegaConf import toml from tqdm import tqdm @@ -13,6 +14,7 @@ import torch from library import deepspeed_utils from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -148,8 +150,10 @@ def train(args): "in_channels": 4, "layers_per_block": 2, "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", "norm_eps": 1e-05, "norm_num_groups": 32, + "num_attention_heads": [5, 10, 20, 20], "num_class_embeds": None, "only_cross_attention": False, "out_channels": 4, @@ -179,8 +183,10 @@ def train(args): "in_channels": 4, "layers_per_block": 2, "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", "norm_eps": 1e-05, "norm_num_groups": 32, + "num_attention_heads": 8, "out_channels": 4, "sample_size": 64, "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], @@ -193,7 +199,23 @@ def train(args): "resnet_time_scale_shift": "default", "projection_class_embeddings_input_dim": None, } - unet.config = SimpleNamespace(**unet.config) + # unet.config = OmegaConf.create(unet.config) + + # make unet.config iterable and accessible by attribute + class CustomConfig: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + def __getattr__(self, name): + if name in self.__dict__: + return self.__dict__[name] + else: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + def __contains__(self, name): + return name in self.__dict__ + + unet.config = CustomConfig(**unet.config) controlnet = ControlNetModel.from_unet(unet) @@ -226,7 +248,7 @@ def train(args): ) vae.to("cpu") clean_memory_on_device(accelerator.device) - + accelerator.wait_for_everyone() if args.gradient_checkpointing: @@ -235,7 +257,7 @@ def train(args): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - trainable_params = controlnet.parameters() + trainable_params = list(controlnet.parameters()) _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -344,7 +366,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -420,7 +444,9 @@ def remove_model(old_ckpt_name): ) # Sample a random timestep for each image - timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device) + timesteps, huber_c = train_util.get_timesteps_and_huber_c( + args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device + ) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -452,7 +478,9 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_db.py b/train_db.py index 1de504ed8..39d8ea6ed 100644 --- a/train_db.py +++ b/train_db.py @@ -290,7 +290,7 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs) # For --sample_at_first train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) @@ -359,7 +359,7 @@ def train(args): target = noise loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/train_network.py b/train_network.py index c99d37247..7ba073855 100644 --- a/train_network.py +++ b/train_network.py @@ -53,7 +53,15 @@ def __init__(self): # TODO 他のスクリプトと共通化する def generate_step_logs( - self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None + self, + args: argparse.Namespace, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + keys_scaled=None, + mean_norm=None, + maximum_norm=None, ): logs = {"loss/current": current_loss, "loss/average": avr_loss} @@ -63,34 +71,26 @@ def generate_step_logs( logs["max_norm/max_key_norm"] = maximum_norm lrs = lr_scheduler.get_last_lr() - - if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block) - if args.network_train_unet_only: - logs["lr/unet"] = float(lrs[0]) - elif args.network_train_text_encoder_only: - logs["lr/textencoder"] = float(lrs[0]) + for i, lr in enumerate(lrs): + if lr_descriptions is not None: + lr_desc = lr_descriptions[i] else: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder - - if ( - args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() - ): # tracking d*lr value of unet. - logs["lr/d*lr"] = ( - lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + idx = i - (0 if args.network_train_unet_only else -1) + if idx == -1: + lr_desc = "textencoder" + else: + if len(lrs) > 2: + lr_desc = f"group{idx}" + else: + lr_desc = "unet" + + logs[f"lr/{lr_desc}"] = lr + + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + # tracking d*lr value + logs[f"lr/d*lr/{lr_desc}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) - else: - idx = 0 - if not args.network_train_unet_only: - logs["lr/textencoder"] = float(lrs[0]) - idx = 1 - - for i in range(idx, len(lrs)): - logs[f"lr/group{i}"] = float(lrs[i]) - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): - logs[f"lr/d*lr/group{i}"] = ( - lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] - ) return logs @@ -323,6 +323,7 @@ def train(self, args): network.apply_to(text_encoder, unet, train_text_encoder, train_unet) if args.network_weights is not None: + # FIXME consider alpha of weights info = network.load_weights(args.network_weights) accelerator.print(f"load network weights from {args.network_weights}: {info}") @@ -338,12 +339,30 @@ def train(self, args): # 後方互換性を確保するよ try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) - except TypeError: - accelerator.print( - "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" - ) + results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + if type(results) is tuple: + trainable_params = results[0] + lr_descriptions = results[1] + else: + trainable_params = results + lr_descriptions = None + except TypeError as e: + # logger.warning(f"{e}") + # accelerator.print( + # "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" + # ) trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + lr_descriptions = None + + # if len(trainable_params) == 0: + # accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした") + # for params in trainable_params: + # for k, v in params.items(): + # if type(v) == float: + # pass + # else: + # v = len(v) + # accelerator.print(f"trainable_params: {k} = {v}") optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) @@ -474,15 +493,26 @@ def train(self, args): # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - if accelerator.is_main_process: + # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 + if accelerator.is_main_process or args.deepspeed: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): - weights.pop(i) + if len(weights) > i: + weights.pop(i) # print(f"save model hook: {len(weights)} weights will be saved") + # save current ecpoch and step + train_state_file = os.path.join(output_dir, "train_state.json") + # +1 is needed because the state is saved before current_step is set from global_step + logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}") + with open(train_state_file, "w", encoding="utf-8") as f: + json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f) + + steps_from_state = None + def load_model_hook(models, input_dir): # remove models except network remove_indices = [] @@ -493,6 +523,15 @@ def load_model_hook(models, input_dir): models.pop(i) # print(f"load model hook: {len(models)} models will be loaded") + # load current epoch and step to + nonlocal steps_from_state + train_state_file = os.path.join(input_dir, "train_state.json") + if os.path.exists(train_state_file): + with open(train_state_file, "r", encoding="utf-8") as f: + data = json.load(f) + steps_from_state = data["current_step"] + logger.info(f"load train state from {train_state_file}: {data}") + accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -736,7 +775,54 @@ def load_model_hook(models, input_dir): if key in metadata: minimum_metadata[key] = metadata[key] - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + # calculate steps to skip when resuming or starting from a specific step + initial_step = 0 + if args.initial_epoch is not None or args.initial_step is not None: + # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming + if steps_from_state is not None: + logger.warning( + "steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます" + ) + if args.initial_step is not None: + initial_step = args.initial_step + else: + # num steps per epoch is calculated by num_processes and gradient_accumulation_steps + initial_step = (args.initial_epoch - 1) * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + else: + # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming + if steps_from_state is not None: + initial_step = steps_from_state + steps_from_state = None + + if initial_step > 0: + assert ( + args.max_train_steps > initial_step + ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}" + + progress_bar = tqdm( + range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" + ) + + epoch_to_start = 0 + if initial_step > 0: + if args.skip_until_initial_step: + # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used + if not args.resume: + logger.info( + f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります" + ) + logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします") + initial_step *= args.gradient_accumulation_steps + + # set epoch to start to make initial_step less than len(train_dataloader) + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + else: + # if not, only epoch no is skipped for informative purpose + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + initial_step = 0 # do not skip + global_step = 0 noise_scheduler = DDPMScheduler( @@ -753,7 +839,9 @@ def load_model_hook(models, input_dir): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "network_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -793,7 +881,13 @@ def remove_model(old_ckpt_name): self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # training loop - for epoch in range(num_train_epochs): + if initial_step > 0: # only if skip_until_initial_step is specified + for skip_epoch in range(epoch_to_start): # skip epochs + logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") + initial_step -= len(train_dataloader) + global_step = initial_step + + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -801,8 +895,17 @@ def remove_model(old_ckpt_name): accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) - for step, batch in enumerate(train_dataloader): + skipped_dataloader = None + if initial_step > 0: + skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1) + initial_step = 1 + + for step, batch in enumerate(skipped_dataloader or train_dataloader): current_step.value = global_step + if initial_step > 0: + initial_step -= 1 + continue + with accelerator.accumulate(training_model): on_step_start(text_encoder, unet) @@ -881,7 +984,7 @@ def remove_model(old_ckpt_name): loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) - if args.masked_loss: + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -950,7 +1053,9 @@ def remove_model(old_ckpt_name): progress_bar.set_postfix(**{**max_mean_logs, **logs}) if args.logging_dir is not None: - logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) + logs = self.generate_step_logs( + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm + ) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: @@ -1101,6 +1206,28 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + "--skip_until_initial_step", + action="store_true", + help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする", + ) + parser.add_argument( + "--initial_epoch", + type=int, + default=None, + help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`." + + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる", + ) + parser.add_argument( + "--initial_step", + type=int, + default=None, + help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", + ) + # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") + # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") + # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") return parser diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 10fce2677..ade077c36 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -510,7 +510,7 @@ def train(self, args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) # function for saving/removing @@ -589,7 +589,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index ddd03d532..efb59137b 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -407,7 +407,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) # function for saving/removing @@ -474,7 +474,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3])