sd3 schedule free opt (#1605)
* New ScheduleFree support for Flux (#1600)

* init

* use no schedule

* fix typo

* update for eval()

* fix typo

* update

* Update train_util.py

* Update requirements.txt

* update sfwrapper WIP

* no need to check schedulefree optimizer

* remove debug print

* comment out schedulefree wrapper

* update readme

---------

Co-authored-by: 青龍聖者@bdsqlsz <[email protected]>
kohya-ss and sdbds committed Sep 17, 2024
1 parent a2ad7e5 commit bbd160b
Showing 3 changed files with 154 additions and 7 deletions.
8 changes: 8 additions & 0 deletions README.md
@@ -11,6 +11,14 @@ The command to install PyTorch is as follows:

### Recent Updates

Sep 18, 2024:

- Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details.
- `schedulefree` is added to the dependencies. Please update the library if necessary.
- AdamWScheduleFree or SGDScheduleFree can be used. Specify `adamwschedulefree` or `sgdschedulefree` in `--optimizer_type`.
- Wrapper classes are not available for now.
- These optimizers can be used not only for FLUX.1 training but also for the other training scripts once this change is merged into the dev/main branch. A usage sketch of the underlying library follows this list.
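
For illustration, a minimal sketch of the underlying `schedulefree` train/eval contract (placeholder model and loop, not sd-scripts code); this is also why no `--lr_scheduler` is needed:

```python
import schedulefree
import torch

model = torch.nn.Linear(10, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

optimizer.train()  # schedule-free optimizers must be put in train mode before training
for _ in range(100):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()  # no lr scheduler: the schedule is handled inside the optimizer

optimizer.eval()  # switch to eval mode before validation or saving the model
```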

Sep 16, 2024:

Added `train_double_block_indices` and `train_single_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details.
152 changes: 145 additions & 7 deletions library/train_util.py
@@ -3303,6 +3303,20 @@ def int_or_float(value):
        help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
    )

    # parser.add_argument(
    #     "--optimizer_schedulefree_wrapper",
    #     action="store_true",
    #     help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用",
    # )

    # parser.add_argument(
    #     "--schedulefree_wrapper_args",
    #     type=str,
    #     default=None,
    #     nargs="*",
    #     help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")',
    # )

parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
parser.add_argument(
"--lr_scheduler_args",
@@ -4582,26 +4596,146 @@ def get_optimizer(args, trainable_params):
        optimizer_class = torch.optim.AdamW
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type.endswith("schedulefree".lower()):
        try:
            import schedulefree as sf
        except ImportError:
            raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
        if optimizer_type == "AdamWScheduleFree".lower():
            optimizer_class = sf.AdamWScheduleFree
            logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
        elif optimizer_type == "SGDScheduleFree".lower():
            optimizer_class = sf.SGDScheduleFree
            logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
        else:
            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
        # set the optimizer to train mode: we don't need to call train() again because eval() will not be called in the training loop
        optimizer.train()

     if optimizer is None:
         # use any optimizer
-        optimizer_type = args.optimizer_type  # not lowercased (not ideal)
-        logger.info(f"use {optimizer_type} | {optimizer_kwargs}")
-        if "." not in optimizer_type:
+        case_sensitive_optimizer_type = args.optimizer_type  # not lower
+        logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")
+
+        if "." not in case_sensitive_optimizer_type:  # from torch.optim
             optimizer_module = torch.optim
-        else:
-            values = optimizer_type.split(".")
+        else:  # from other library
+            values = case_sensitive_optimizer_type.split(".")
             optimizer_module = importlib.import_module(".".join(values[:-1]))
-            optimizer_type = values[-1]
+            case_sensitive_optimizer_type = values[-1]

-        optimizer_class = getattr(optimizer_module, optimizer_type)
+        optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

"""
# wrap any of above optimizer with schedulefree, if optimizer is not schedulefree
if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree".lower()):
try:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
schedulefree_wrapper_kwargs = {}
if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0:
for arg in args.schedulefree_wrapper_args:
key, value = arg.split("=")
value = ast.literal_eval(value)
schedulefree_wrapper_kwargs[key] = value
sf_wrapper = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs)
sf_wrapper.train() # make optimizer as train mode
# we need to make optimizer as a subclass of torch.optim.Optimizer, we make another Proxy class over SFWrapper
class OptimizerProxy(torch.optim.Optimizer):
def __init__(self, sf_wrapper):
self._sf_wrapper = sf_wrapper
def __getattr__(self, name):
return getattr(self._sf_wrapper, name)
# override properties
@property
def state(self):
return self._sf_wrapper.state
@state.setter
def state(self, state):
self._sf_wrapper.state = state
@property
def param_groups(self):
return self._sf_wrapper.param_groups
@param_groups.setter
def param_groups(self, param_groups):
self._sf_wrapper.param_groups = param_groups
@property
def defaults(self):
return self._sf_wrapper.defaults
@defaults.setter
def defaults(self, defaults):
self._sf_wrapper.defaults = defaults
def add_param_group(self, param_group):
self._sf_wrapper.add_param_group(param_group)
def load_state_dict(self, state_dict):
self._sf_wrapper.load_state_dict(state_dict)
def state_dict(self):
return self._sf_wrapper.state_dict()
def zero_grad(self):
self._sf_wrapper.zero_grad()
def step(self, closure=None):
self._sf_wrapper.step(closure)
def train(self):
self._sf_wrapper.train()
def eval(self):
self._sf_wrapper.eval()
# isinstance チェックをパスするためのメソッド
def __instancecheck__(self, instance):
return isinstance(instance, (type(self), Optimizer))
optimizer = OptimizerProxy(sf_wrapper)
logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}")
"""

    optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
    optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])

    return optimizer_name, optimizer_args, optimizer


def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool:
    return args.optimizer_type.lower().endswith("schedulefree".lower())  # or args.optimizer_schedulefree_wrapper


def get_dummy_scheduler(optimizer: Optimizer) -> Any:
    # dummy scheduler for schedulefree optimizers. supports only empty step(), get_last_lr() and optimizers.
    # this scheduler is used for logging only.
    # this is not wrapped by accelerator because this class is not a subclass of torch.optim.lr_scheduler._LRScheduler
    class DummyScheduler:
        def __init__(self, optimizer: Optimizer):
            self.optimizer = optimizer

        def step(self):
            pass

        def get_last_lr(self):
            return [group["lr"] for group in self.optimizer.param_groups]

    return DummyScheduler(optimizer)


# Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler
# Add some checking and features to the original function.

@@ -4610,6 +4744,10 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
"""
Unified API to get any scheduler from its name.
"""
# if schedulefree optimizer, return dummy scheduler
if is_schedulefree_optimizer(optimizer, args):
return get_dummy_scheduler(optimizer)

name = args.lr_scheduler
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
num_warmup_steps: Optional[int] = (
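
A minimal sketch (illustration only, not part of this commit) of roughly what the new code path above does for `--optimizer_type adamwschedulefree`; the parameter list and kwargs below are placeholders:

```python
import schedulefree
import torch

# roughly what get_optimizer() does for `--optimizer_type adamwschedulefree --optimizer_args weight_decay=0.01`
trainable_params = [torch.nn.Parameter(torch.zeros(4))]
optimizer_kwargs = {"weight_decay": 0.01}  # parsed from --optimizer_args
optimizer = schedulefree.AdamWScheduleFree(trainable_params, lr=1e-3, **optimizer_kwargs)
optimizer.train()  # left in train mode, as in the diff above

# get_scheduler_fix() then returns the dummy scheduler instead of a real lr scheduler,
# so scheduler.step() is a no-op and get_last_lr() simply reads param_groups for logging.
```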
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,6 +9,7 @@ pytorch-lightning==1.9.0
bitsandbytes==0.43.3
prodigyopt==1.0
lion-pytorch==0.0.6
schedulefree==1.2.7
tensorboard
safetensors==0.4.4
# gradio==3.16.2
