New ScheduleFree support for Flux (#1600)
* init

* use no schedule

* fix typo

* update for eval()

* fix typo

* update

* Update train_util.py

* Update requirements.txt
sdbds committed Sep 16, 2024
1 parent 96c677b commit e634975
Showing 12 changed files with 253 additions and 53 deletions.
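Every file below applies the same pattern: ScheduleFree optimizers replace the external learning-rate schedule with iterate averaging, so when args.optimizer_type ends in "schedulefree" or args.optimizer_schedulefree_wrapper is set, the scripts stop passing lr_scheduler to accelerator.prepare(), skip lr_scheduler.step(), and instead call optimizer.train() before each training step and optimizer.eval() before evaluation, sampling, or saving. A minimal sketch of that protocol with a toy model (not repo code), assuming the schedulefree package pinned in requirements.txt:

```python
import torch
import schedulefree as sf

# toy stand-ins; the real scripts train Flux / SDXL models prepared by accelerate
model = torch.nn.Linear(16, 1)
optimizer = sf.AdamWScheduleFree(model.parameters(), lr=1e-4)

for step in range(100):
    optimizer.train()   # switch to the training parameter sequence before forward/backward
    x = torch.randn(8, 16)
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()    # no lr_scheduler.step(): the schedule is built into the optimizer
    optimizer.zero_grad(set_to_none=True)

    optimizer.eval()    # switch to the averaged parameters before any eval/sampling/saving
```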
35 changes: 27 additions & 8 deletions fine_tune.py
@@ -272,18 +272,31 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
else:
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
# apparently accelerator takes care of this for us somehow
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader
)
else:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

# experimental feature: do fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
@@ -350,6 +363,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
m.train()

for step, batch in enumerate(train_dataloader):
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.train()
current_step.value = global_step
with accelerator.accumulate(*training_models):
with torch.no_grad():
@@ -425,9 +440,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
27 changes: 21 additions & 6 deletions flux_train.py
@@ -416,9 +416,14 @@ def train(args):
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
else:
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]

else:
@@ -427,7 +432,10 @@ def train(args):
flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks])
if is_swapping_blocks:
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
else:
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

# experimental feature: do fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
@@ -643,6 +651,8 @@ def optimizer_hook(parameter: torch.Tensor):
m.train()

for step, batch in enumerate(train_dataloader):
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.train()
current_step.value = global_step

if args.blockwise_fused_optimizers:
@@ -746,15 +756,20 @@ def optimizer_hook(parameter: torch.Tensor):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if not args.optimizer_schedulefree_wrapper:
lr_scheduler.step()
if args.blockwise_fused_optimizers:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()

if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
49 changes: 47 additions & 2 deletions library/train_util.py
@@ -3303,6 +3303,20 @@ def int_or_float(value):
help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
)

parser.add_argument(
"--optimizer_schedulefree_wrapper",
action="store_true",
help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用",
)

parser.add_argument(
"--schedulefree_wrapper_args",
type=str,
default=None,
nargs="*",
help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")',
)
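Both new flags follow the existing optimizer_args convention: space-separated key=value strings parsed with ast.literal_eval into keyword arguments (see the get_optimizer() change below). A small illustrative sketch of that parsing, with made-up values:

```python
import ast

# e.g. --schedulefree_wrapper_args "momentum=0.9" "weight_decay_at_y=0.1"
wrapper_args = ["momentum=0.9", "weight_decay_at_y=0.1"]
wrapper_kwargs = {}
for arg in wrapper_args:
    key, value = arg.split("=")
    wrapper_kwargs[key] = ast.literal_eval(value)  # "0.9" -> 0.9, "True" -> True, ...

print(wrapper_kwargs)  # {'momentum': 0.9, 'weight_decay_at_y': 0.1}
```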

parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
parser.add_argument(
"--lr_scheduler_args",
@@ -4361,6 +4375,8 @@ def get_optimizer(args, trainable_params):
optimizer_kwargs[key] = value
# logger.info(f"optkwargs {optimizer}_{kwargs}")

schedulefree_wrapper_kwargs = {}

lr = args.learning_rate
optimizer = None

@@ -4581,20 +4597,49 @@ def get_optimizer(args, trainable_params):
logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

elif optimizer_type.endswith("schedulefree".lower()):
try:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
if optimizer_type == "AdamWScheduleFree".lower():
optimizer_class = sf.AdamWScheduleFree
logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
elif optimizer_type == "SGDScheduleFree".lower():
optimizer_class = sf.SGDScheduleFree
logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
else:
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

if optimizer is None:
# use an arbitrary optimizer
optimizer_type = args.optimizer_type # the non-lowercased one (a bit questionable)
logger.info(f"use {optimizer_type} | {optimizer_kwargs}")
if "." not in optimizer_type:
optimizer_module = torch.optim
optimizer_class = getattr(optimizer_module, optimizer_type)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree"):
try:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")

if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0:
for arg in args.schedulefree_wrapper_args:
key, value = arg.split("=")
value = ast.literal_eval(value)
schedulefree_wrapper_kwargs[key] = value
optimizer = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs)
else:
values = optimizer_type.split(".")
optimizer_module = importlib.import_module(".".join(values[:-1]))
optimizer_type = values[-1]

optimizer_class = getattr(optimizer_module, optimizer_type)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
optimizer_class = getattr(optimizer_module, optimizer_type)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
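With these branches, get_optimizer() either instantiates a native ScheduleFree optimizer (AdamWScheduleFree or SGDScheduleFree) or, when --optimizer_schedulefree_wrapper is set, wraps a plain torch.optim optimizer in sf.ScheduleFreeWrapper. A hedged sketch of the wrapper path; the SGD base and its settings are illustrative, not repo defaults:

```python
import torch
import schedulefree as sf

param = torch.nn.Parameter(torch.randn(4, 4))        # placeholder for trainable_params
base_optimizer = torch.optim.SGD([param], lr=1e-3)
optimizer = sf.ScheduleFreeWrapper(base_optimizer, momentum=0.9, weight_decay_at_y=0.1)

optimizer.train()                       # before backward/step, as in the training loops above
loss = (param ** 2).sum()
loss.backward()
optimizer.step()                        # no external lr_scheduler needed
optimizer.zero_grad(set_to_none=True)

optimizer.eval()                        # before evaluation or saving state
```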
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,6 +9,7 @@ pytorch-lightning==1.9.0
bitsandbytes==0.43.3
prodigyopt==1.0
lion-pytorch==0.0.6
schedulefree==1.2.7
tensorboard
safetensors==0.4.4
# gradio==3.16.2
27 changes: 21 additions & 6 deletions sdxl_train.py
@@ -484,9 +484,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder2=text_encoder2 if train_text_encoder2 else None,
)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
else:
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]

else:
@@ -497,7 +502,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
else:
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

# move to CPU when caching the TextEncoder outputs
if args.cache_text_encoder_outputs:
@@ -630,6 +638,8 @@ def optimizer_hook(parameter: torch.Tensor):
m.train()

for step, batch in enumerate(train_dataloader):
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.train()
current_step.value = global_step

if args.fused_optimizer_groups:
@@ -749,15 +759,20 @@ def optimizer_hook(parameter: torch.Tensor):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if not args.optimizer_schedulefree_wrapper:
lr_scheduler.step()
if args.fused_optimizer_groups:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()

if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
18 changes: 16 additions & 2 deletions sdxl_train_control_net_lllite.py
@@ -307,14 +307,22 @@ def train(args):
unet.to(weight_dtype)

# apparently accelerator takes care of this for us somehow
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

if isinstance(unet, DDP):
unet._set_static_graph() # avoid error for multiple use of the parameter

if args.gradient_checkpointing:
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.train()
unet.train() # according to TI example in Diffusers, train is required -> since this is the original U-Net, it can actually be removed

else:
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.eval()
unet.eval()

# move to CPU when caching the TextEncoder outputs
@@ -416,6 +424,8 @@ def remove_model(old_ckpt_name):
current_epoch.value = epoch + 1

for step, batch in enumerate(train_dataloader):
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.train()
current_step.value = global_step
with accelerator.accumulate(unet):
with torch.no_grad():
@@ -510,9 +520,13 @@ def remove_model(old_ckpt_name):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
23 changes: 19 additions & 4 deletions sdxl_train_control_net_lllite_old.py
@@ -254,15 +254,24 @@ def train(args):
network.to(weight_dtype)

# apparently accelerator takes care of this for us somehow
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
unet, network, optimizer, train_dataloader = accelerator.prepare(
unet, network, optimizer, train_dataloader
)
else:
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
network: control_net_lllite.ControlNetLLLite

if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> since this is the original U-Net, it can actually be removed
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.train()
else:
unet.eval()
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.eval()

network.prepare_grad_etc()

@@ -357,6 +366,8 @@ def remove_model(old_ckpt_name):
network.on_epoch_start() # train()

for step, batch in enumerate(train_dataloader):
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.train()
current_step.value = global_step
with accelerator.accumulate(network):
with torch.no_grad():
@@ -449,9 +460,13 @@ def remove_model(old_ckpt_name):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)