diff --git a/fine_tune.py b/fine_tune.py
index fb6b3ed69..9d586e35f 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -272,18 +272,31 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]
     else:
         # acceleratorがなんかよろしくやってくれるらしい
         if args.train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader
+                )
+            else:
+                unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+                )
         else:
-            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+            else:
+                unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
 
     # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
     if args.full_fp16:
@@ -350,6 +363,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             m.train()
 
         for step, batch in enumerate(train_dataloader):
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(*training_models):
                 with torch.no_grad():
@@ -425,9 +440,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
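Every script in this patch follows the same contract for schedule-free optimizers: there is no LR scheduler to prepare or step, the optimizer must be in train mode before forward/backward/step, and in eval mode before anything that reads the weights for inference (validation, sampling, checkpointing). A minimal sketch of that contract with a toy model; the model and data here are placeholders, not part of this patch:

import torch
import schedulefree

model = torch.nn.Linear(4, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

for step in range(100):
    optimizer.train()                      # train-mode weights before forward/backward
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()                       # no lr_scheduler.step(): scheduling is internal
    optimizer.zero_grad(set_to_none=True)
    optimizer.eval()                       # averaged weights for eval/sampling/saving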
diff --git a/flux_train.py b/flux_train.py
index 33481df8f..2e8de0ee8 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -416,9 +416,14 @@ def train(args):
     if args.deepspeed:
         ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux)
         # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]
 
     else:
@@ -427,7 +432,10 @@ def train(args):
         flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks])
         if is_swapping_blocks:
             accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
-        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+        if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+            optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
+        else:
+            optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
     # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
     if args.full_fp16:
@@ -643,6 +651,8 @@ def optimizer_hook(parameter: torch.Tensor):
             m.train()
 
         for step, batch in enumerate(train_dataloader):
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.train()
             current_step.value = global_step
 
             if args.blockwise_fused_optimizers:
@@ -746,15 +756,20 @@ def optimizer_hook(parameter: torch.Tensor):
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                     optimizer.step()
-                    lr_scheduler.step()
+                    if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
+                        lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
                 else:
                     # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
-                    lr_scheduler.step()
+                    if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
+                        lr_scheduler.step()
                     if args.blockwise_fused_optimizers:
                         for i in range(1, len(optimizers)):
                             lr_schedulers[i].step()
 
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
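The predicate args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper is repeated verbatim at every call site in this patch; a small helper in library/train_util.py (hypothetical, not part of this patch) would keep the scripts from drifting apart:

# Hypothetical helper (not in this patch): centralize the schedule-free check.
def is_schedulefree_optimizer(args) -> bool:
    return args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper

# Call sites would then read:
#     if is_schedulefree_optimizer(args):
#         optimizer.train()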
diff --git a/library/train_util.py b/library/train_util.py
index 60afd4219..c9a162893 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3303,6 +3303,20 @@ def int_or_float(value):
         help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
     )
 
+    parser.add_argument(
+        "--optimizer_schedulefree_wrapper",
+        action="store_true",
+        help="use schedulefree_wrapper for any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用",
+    )
+
+    parser.add_argument(
+        "--schedulefree_wrapper_args",
+        type=str,
+        default=None,
+        nargs="*",
+        help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / schedulefree_wrapperの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")',
+    )
+
     parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
     parser.add_argument(
         "--lr_scheduler_args",
@@ -4361,6 +4375,8 @@ def get_optimizer(args, trainable_params):
             optimizer_kwargs[key] = value
     # logger.info(f"optkwargs {optimizer}_{kwargs}")
 
+    schedulefree_wrapper_kwargs = {}
+
     lr = args.learning_rate
 
     optimizer = None
@@ -4581,6 +4597,21 @@ def get_optimizer(args, trainable_params):
         logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
         optimizer_class = torch.optim.AdamW
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+    elif optimizer_type.endswith("schedulefree".lower()):
+        try:
+            import schedulefree as sf
+        except ImportError:
+            raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
+        if optimizer_type == "AdamWScheduleFree".lower():
+            optimizer_class = sf.AdamWScheduleFree
+            logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
+        elif optimizer_type == "SGDScheduleFree".lower():
+            optimizer_class = sf.SGDScheduleFree
+            logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
+        else:
+            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
     if optimizer is None:
         # 任意のoptimizerを使う
@@ -4588,13 +4619,27 @@ def get_optimizer(args, trainable_params):
         logger.info(f"use {optimizer_type} | {optimizer_kwargs}")
 
         if "." not in optimizer_type:
             optimizer_module = torch.optim
+            optimizer_class = getattr(optimizer_module, optimizer_type)
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+            if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree"):
+                try:
+                    import schedulefree as sf
+                except ImportError:
+                    raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
+
+                if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0:
+                    for arg in args.schedulefree_wrapper_args:
+                        key, value = arg.split("=")
+                        value = ast.literal_eval(value)
+                        schedulefree_wrapper_kwargs[key] = value
+                optimizer = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs)
         else:
             values = optimizer_type.split(".")
             optimizer_module = importlib.import_module(".".join(values[:-1]))
             optimizer_type = values[-1]
-        optimizer_class = getattr(optimizer_module, optimizer_type)
-        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+            optimizer_class = getattr(optimizer_module, optimizer_type)
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
     optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
     optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
diff --git a/requirements.txt b/requirements.txt
index 9a4fa0c15..bab53f20f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,6 +9,7 @@ pytorch-lightning==1.9.0
 bitsandbytes==0.43.3
 prodigyopt==1.0
 lion-pytorch==0.0.6
+schedulefree==1.2.7
 tensorboard
 safetensors==0.4.4
 # gradio==3.16.2
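The new --schedulefree_wrapper_args flag reuses the same key=value parsing as --optimizer_args. Roughly, wrapping a plain torch.optim.SGD behaves like the sketch below; the ScheduleFreeWrapper keyword names follow the schedulefree project's documentation and are assumptions here, not verified against the pinned release:

import ast
import torch
import schedulefree as sf

raw_args = ["momentum=0.9", "weight_decay_at_y=0.1"]  # --schedulefree_wrapper_args
wrapper_kwargs = {}
for arg in raw_args:
    key, value = arg.split("=")
    wrapper_kwargs[key] = ast.literal_eval(value)      # "0.9" -> 0.9

base = torch.optim.SGD(torch.nn.Linear(4, 1).parameters(), lr=1e-3)
optimizer = sf.ScheduleFreeWrapper(base, **wrapper_kwargs)  # assumed signature

Note that the wrapper path in get_optimizer() above only applies to optimizers resolved from torch.optim (the "." not in optimizer_type branch); dotted optimizer types are instantiated without the wrapper.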
diff --git a/sdxl_train.py b/sdxl_train.py
index 7291ddd2f..35e54a4a8 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -484,9 +484,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder2=text_encoder2 if train_text_encoder2 else None,
         )
         # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]
 
     else:
@@ -497,7 +502,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder1 = accelerator.prepare(text_encoder1)
         if train_text_encoder2:
             text_encoder2 = accelerator.prepare(text_encoder2)
-        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+        if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+            optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
+        else:
+            optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
@@ -630,6 +638,8 @@ def optimizer_hook(parameter: torch.Tensor):
             m.train()
 
         for step, batch in enumerate(train_dataloader):
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.train()
             current_step.value = global_step
 
             if args.fused_optimizer_groups:
@@ -749,15 +759,20 @@ def optimizer_hook(parameter: torch.Tensor):
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                     optimizer.step()
-                    lr_scheduler.step()
+                    if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
+                        lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
                 else:
                     # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
-                    lr_scheduler.step()
+                    if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
+                        lr_scheduler.step()
                     if args.fused_optimizer_groups:
                         for i in range(1, len(optimizers)):
                             lr_schedulers[i].step()
 
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
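sdxl_train.py and flux_train.py also have fused/blockwise optimizer paths that step one scheduler per parameter group; the patch skips the main scheduler there but leaves the per-group lr_schedulers[i].step() calls untouched. A guard like the one below (hypothetical, not part of this patch) would make the likely unsupported combination fail early instead:

# Hypothetical guard (not in this patch): schedule-free optimizers bypass LR
# scheduling, so combining them with fused/blockwise optimizer groups is
# probably unintended and could be rejected up front.
def validate_schedulefree_combination(args):
    schedulefree = args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper
    fused = getattr(args, "fused_optimizer_groups", None) or getattr(args, "blockwise_fused_optimizers", False)
    if schedulefree and fused:
        raise ValueError("schedule-free optimizers cannot be combined with fused/blockwise optimizer groups")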
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 9d1cfc63e..e5f5bd782 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -307,14 +307,22 @@ def train(args):
         unet.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい
-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+    if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+        unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
 
     if isinstance(unet, DDP):
         unet._set_static_graph()  # avoid error for multiple use of the parameter
 
     if args.gradient_checkpointing:
+        if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+            optimizer.train()
         unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
+
     else:
+        if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+            optimizer.eval()
         unet.eval()
 
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
@@ -416,6 +424,8 @@ def remove_model(old_ckpt_name):
         current_epoch.value = epoch + 1
 
         for step, batch in enumerate(train_dataloader):
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(unet):
                 with torch.no_grad():
@@ -510,9 +520,13 @@ def remove_model(old_ckpt_name):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 6fa1d6096..dc49636c4 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -254,15 +254,24 @@ def train(args):
     network.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい
-    unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, network, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+        unet, network, optimizer, train_dataloader = accelerator.prepare(
+            unet, network, optimizer, train_dataloader
+        )
+    else:
+        unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, network, optimizer, train_dataloader, lr_scheduler
+        )
     network: control_net_lllite.ControlNetLLLite
 
     if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
+        if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+            optimizer.train()
     else:
         unet.eval()
+        if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+            optimizer.eval()
 
     network.prepare_grad_etc()
 
@@ -357,6 +366,8 @@ def remove_model(old_ckpt_name):
         network.on_epoch_start()  # train()
 
         for step, batch in enumerate(train_dataloader):
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(network):
                 with torch.no_grad():
@@ -449,9 +460,13 @@ def remove_model(old_ckpt_name):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/train_controlnet.py b/train_controlnet.py
index 57f0d263f..c09fc7e49 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -298,9 +298,14 @@ def __contains__(self, name):
     controlnet.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい
-    controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        controlnet, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+        controlnet, optimizer, train_dataloader = accelerator.prepare(
+            controlnet, optimizer, train_dataloader
+        )
+    else:
+        controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            controlnet, optimizer, train_dataloader, lr_scheduler
+        )
 
     unet.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -420,6 +425,8 @@ def remove_model(old_ckpt_name):
         current_epoch.value = epoch + 1
 
         for step, batch in enumerate(train_dataloader):
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(controlnet):
                 with torch.no_grad():
@@ -500,9 +507,13 @@ def remove_model(old_ckpt_name):
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/train_db.py b/train_db.py
index d42afd89a..b9871d9cf 100644
--- a/train_db.py
+++ b/train_db.py
@@ -244,19 +244,32 @@ def train(args):
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]
 
     else:
        if train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader
+                )
+            else:
+                unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+                )
             training_models = [unet, text_encoder]
         else:
-            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+            else:
+                unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
             training_models = [unet]
 
     if not train_text_encoder:
@@ -331,6 +344,8 @@ def train(args):
             text_encoder.train()
 
         for step, batch in enumerate(train_dataloader):
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.train()
             current_step.value = global_step
             # 指定したステップ数でText Encoderの学習を止める
             if global_step == args.stop_text_encoder_training:
@@ -414,9 +429,13 @@ def train(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/train_network.py b/train_network.py
index 34385ae08..ac1daedcd 100644
--- a/train_network.py
+++ b/train_network.py
@@ -587,9 +587,14 @@ def train(self, args):
                 text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None,
                 network=network,
             )
-            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                ds_model, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                ds_model, optimizer, train_dataloader = accelerator.prepare(
+                    ds_model, optimizer, train_dataloader
+                )
+            else:
+                ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    ds_model, optimizer, train_dataloader, lr_scheduler
+                )
             training_model = ds_model
         else:
             if train_unet:
@@ -607,15 +612,23 @@ def train(self, args):
                     text_encoder = text_encoders[0]
             else:
                 pass  # if text_encoder is not trained, no need to prepare. and device and dtype are already set
-
-            network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                network, optimizer, train_dataloader, lr_scheduler
-            )
+
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                network, optimizer, train_dataloader = accelerator.prepare(
+                    network, optimizer, train_dataloader
+                )
+            else:
+                network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    network, optimizer, train_dataloader, lr_scheduler
+                )
             training_model = network
 
         if args.gradient_checkpointing:
             # according to TI example in Diffusers, train is required
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.train()
             unet.train()
+
             for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))):
                 t_enc.train()
@@ -624,6 +637,8 @@ def train(self, args):
                     self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc)
 
         else:
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.eval()
             unet.eval()
             for t_enc in text_encoders:
                 t_enc.eval()
@@ -1074,6 +1089,8 @@ def remove_model(old_ckpt_name):
                 initial_step = 1
 
         for step, batch in enumerate(skipped_dataloader or train_dataloader):
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.train()
             current_step.value = global_step
             if initial_step > 0:
                 initial_step -= 1
@@ -1183,7 +1200,8 @@ def remove_model(old_ckpt_name):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
                 if args.scale_weight_norms:
@@ -1194,6 +1212,9 @@ def remove_model(old_ckpt_name):
                 else:
                     keys_scaled, mean_norm, maximum_norm = None, None, None
 
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 956c78603..48abdcc3d 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -452,8 +452,12 @@ def train(self, args):
         unet.to(accelerator.device, dtype=weight_dtype)
     if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
         # TODO U-Netをオリジナルに置き換えたのでいらないはずなので、後で確認して消す
+        if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+            optimizer.train()
         unet.train()
     else:
+        if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+            optimizer.eval()
         unet.eval()
 
     text_encoding_strategy = self.get_text_encoding_strategy(args)
@@ -565,6 +569,8 @@ def remove_model(old_ckpt_name):
         loss_total = 0
 
         for step, batch in enumerate(train_dataloader):
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(text_encoders[0]):
                 with torch.no_grad():
@@ -628,7 +634,8 @@ def remove_model(old_ckpt_name):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
                 # Let's make sure we don't update any embedding weights besides the newly added token
@@ -642,6 +649,9 @@ def remove_model(old_ckpt_name):
                         index_no_updates
                     ]
 
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
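The optimizer.eval() call after each step matters beyond symmetry: schedule-free optimizers keep separate train-mode and averaged iterates, and only the eval-mode (averaged) weights are the ones intended for inference, so sampling and checkpoint writes should happen in eval mode. A short sketch under that assumption (model, path, and helper name are placeholders, not part of this patch):

import torch
import schedulefree

model = torch.nn.Linear(4, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

def sample_and_save(path="checkpoint.pt"):
    optimizer.eval()                      # expose the averaged (inference) weights
    with torch.no_grad():
        _ = model(torch.randn(1, 4))      # sample / validate with eval-mode weights
    torch.save(model.state_dict(), path)  # checkpoints are written from eval mode
    optimizer.train()                     # back to train-mode weights before resuming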
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index ca0b603fb..8ebb2a8c5 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -335,9 +335,14 @@ def train(args):
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
     # acceleratorがなんかよろしくやってくれるらしい
-    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+        text_encoder, optimizer, train_dataloader = accelerator.prepare(
+            text_encoder, optimizer, train_dataloader
+        )
+    else:
+        text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            text_encoder, optimizer, train_dataloader, lr_scheduler
+        )
 
     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
     # logger.info(len(index_no_updates), torch.sum(index_no_updates))
@@ -354,8 +359,12 @@ def train(args):
     unet.to(accelerator.device, dtype=weight_dtype)
     if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
         unet.train()
+        if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+            optimizer.train()
     else:
         unet.eval()
+        if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+            optimizer.eval()
 
     if not cache_latents:
         vae.requires_grad_(False)
@@ -438,6 +447,8 @@ def remove_model(old_ckpt_name):
         loss_total = 0
 
         for step, batch in enumerate(train_dataloader):
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(text_encoder):
                 with torch.no_grad():
@@ -496,7 +507,8 @@ def remove_model(old_ckpt_name):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
                 # Let's make sure we don't update any embedding weights besides the newly added token
@@ -505,6 +517,9 @@ def remove_model(old_ckpt_name):
                         index_no_updates
                     ]
 
+            if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
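Usage summary: the new options can be combined in two ways. Either select a native schedule-free optimizer with --optimizer_type AdamWScheduleFree (or SGDScheduleFree), or wrap an arbitrary torch.optim optimizer with --optimizer_type SGD --optimizer_schedulefree_wrapper --schedulefree_wrapper_args "momentum=0.9" "weight_decay_at_y=0.1" (values here are illustrative). In both cases the scripts skip preparing and stepping the LR scheduler, so --lr_scheduler settings effectively have no effect.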