From 1286e00bb0fc34c296f24b7057777f1c37cf8e11 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Wed, 18 Sep 2024 21:31:54 +0900
Subject: [PATCH] fix to call train/eval in schedulefree #1605

---
 README.md             |  3 +++
 flux_train.py         | 10 ++++++++++
 library/train_util.py | 15 ++++++++++++++-
 train_network.py      |  6 ++++++
 4 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 034a260ff..843ae181b 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,9 @@ The command to install PyTorch is as follows:
 
 ### Recent Updates
 
+Sep 18, 2024 (update 1):
+Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now.
+
 Sep 18, 2024:
 - Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details.
 
diff --git a/flux_train.py b/flux_train.py
index 5d8326b1d..bc4e62793 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -347,8 +347,13 @@ def train(args):
 
         logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
 
+        if train_util.is_schedulefree_optimizer(optimizers[0], args):
+            raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
+        optimizer_train_fn = lambda: None  # dummy function
+        optimizer_eval_fn = lambda: None  # dummy function
     else:
         _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+        optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
 
     # prepare dataloader
     # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
@@ -760,6 +765,7 @@ def optimizer_hook(parameter: torch.Tensor):
                 progress_bar.update(1)
                 global_step += 1
 
+                optimizer_eval_fn()
                 flux_train_utils.sample_images(
                     accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
                 )
@@ -778,6 +784,7 @@ def optimizer_hook(parameter: torch.Tensor):
                         global_step,
                         accelerator.unwrap_model(flux),
                     )
+            optimizer_train_fn()
 
             current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
             if len(accelerator.trackers) > 0:
@@ -800,6 +807,7 @@ def optimizer_hook(parameter: torch.Tensor):
 
         accelerator.wait_for_everyone()
 
+        optimizer_eval_fn()
         if args.save_every_n_epochs is not None:
             if accelerator.is_main_process:
                 flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
@@ -816,12 +824,14 @@ def optimizer_hook(parameter: torch.Tensor):
         flux_train_utils.sample_images(
             accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
         )
+        optimizer_train_fn()
 
     is_main_process = accelerator.is_main_process
     # if is_main_process:
     flux = accelerator.unwrap_model(flux)
 
     accelerator.end_training()
+    optimizer_eval_fn()
 
     if args.save_state or args.save_state_on_train_end:
         train_util.save_state_on_train_end(args, accelerator)
diff --git a/library/train_util.py b/library/train_util.py
index a54f23ff6..fe9deb940 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -13,6 +13,7 @@
 import time
 from typing import (
     Any,
+    Callable,
     Dict,
     List,
     NamedTuple,
@@ -4715,8 +4716,20 @@ def __instancecheck__(self, instance):
     return optimizer_name, optimizer_args, optimizer
 
 
+def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]:
+    if not is_schedulefree_optimizer(optimizer, args):
+        # return dummy func
+        return lambda: None, lambda: None
+
+    # get train and eval functions from optimizer
+    train_fn = optimizer.train
+    eval_fn = optimizer.eval
+
+    return train_fn, eval_fn
+
+
 def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool:
-    return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
+    return args.optimizer_type.lower().endswith("schedulefree".lower())  # or args.optimizer_schedulefree_wrapper
 
 
 def get_dummy_scheduler(optimizer: Optimizer) -> Any:
diff --git a/train_network.py b/train_network.py
index 34385ae08..55faa143e 100644
--- a/train_network.py
+++ b/train_network.py
@@ -498,6 +498,7 @@ def train(self, args):
                 # accelerator.print(f"trainable_params: {k} = {v}")
 
         optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
+        optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
 
         # prepare dataloader
         # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
@@ -1199,6 +1200,7 @@ def remove_model(old_ckpt_name):
                     progress_bar.update(1)
                     global_step += 1
 
+                    optimizer_eval_fn()
                     self.sample_images(
                         accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
                     )
@@ -1217,6 +1219,7 @@ def remove_model(old_ckpt_name):
                         if remove_step_no is not None:
                             remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
                             remove_model(remove_ckpt_name)
+                optimizer_train_fn()
 
                 current_loss = loss.detach().item()
                 loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
@@ -1243,6 +1246,7 @@ def remove_model(old_ckpt_name):
             accelerator.wait_for_everyone()
 
             # 指定エポックごとにモデルを保存
+            optimizer_eval_fn()
             if args.save_every_n_epochs is not None:
                 saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
                 if is_main_process and saving:
@@ -1258,6 +1262,7 @@ def remove_model(old_ckpt_name):
                     train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
 
             self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
+            optimizer_train_fn()
 
             # end of epoch
 
@@ -1268,6 +1273,7 @@ def remove_model(old_ckpt_name):
         network = accelerator.unwrap_model(network)
 
         accelerator.end_training()
+        optimizer_eval_fn()
 
         if is_main_process and (args.save_state or args.save_state_on_train_end):
            train_util.save_state_on_train_end(args, accelerator)
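
Note (not part of the patch): the hunks above only wire `optimizer_train_fn()` / `optimizer_eval_fn()` into the existing loops. The sketch below shows the underlying pattern the patch relies on, assuming the upstream `schedulefree` package (`AdamWScheduleFree`); `model`, `loader`, and `"ckpt.pt"` are illustrative placeholders, not names from sd-scripts.

```python
# Minimal sketch of the schedule-free train()/eval() pattern, assuming the
# `schedulefree` package is installed. Placeholders: model, loader, "ckpt.pt".
import torch
import torch.nn as nn
import schedulefree

model = nn.Linear(16, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

loader = [(torch.randn(4, 16), torch.randn(4, 1)) for _ in range(20)]

optimizer.train()  # must be called before training steps (what optimizer_train_fn() does)
for step, (x, y) in enumerate(loader):
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if (step + 1) % 10 == 0:
        # Switch to the evaluation weights before sampling/saving, mirroring
        # the optimizer_eval_fn() calls added before sample_images() and saves.
        optimizer.eval()
        torch.save(model.state_dict(), "ckpt.pt")
        optimizer.train()  # switch back before the next training step

optimizer.eval()  # final switch before the last save, as done after end_training()
torch.save(model.state_dict(), "ckpt.pt")
```

Schedule-free optimizers keep the fast training iterate and the averaged evaluation weights separate, so a checkpoint or sample taken without `eval()` would use the wrong parameters; the added call sites toggle `eval()` around sampling/saving and `train()` before resuming optimization.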