Skip to content

Commit

Permalink
comment out schedulefree wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Sep 17, 2024
1 parent ae2eaf9 commit 5ad328a
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3303,19 +3303,19 @@ def int_or_float(value):
help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
)

parser.add_argument(
"--optimizer_schedulefree_wrapper",
action="store_true",
help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用",
)
# parser.add_argument(
# "--optimizer_schedulefree_wrapper",
# action="store_true",
# help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用",
# )

parser.add_argument(
"--schedulefree_wrapper_args",
type=str,
default=None,
nargs="*",
help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")',
)
# parser.add_argument(
# "--schedulefree_wrapper_args",
# type=str,
# default=None,
# nargs="*",
# help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")',
# )

parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
parser.add_argument(
Expand Down Expand Up @@ -4375,8 +4375,6 @@ def get_optimizer(args, trainable_params):
optimizer_kwargs[key] = value
# logger.info(f"optkwargs {optimizer}_{kwargs}")

schedulefree_wrapper_kwargs = {}

lr = args.learning_rate
optimizer = None

Expand Down Expand Up @@ -4622,23 +4620,23 @@ def get_optimizer(args, trainable_params):

if "." not in case_sensitive_optimizer_type: # from torch.optim
optimizer_module = torch.optim
optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
else: # from other library
values = case_sensitive_optimizer_type.split(".")
optimizer_module = importlib.import_module(".".join(values[:-1]))
case_sensitive_optimizer_type = values[-1]

optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

"""
# wrap any of above optimizer with schedulefree, if optimizer is not schedulefree
if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree".lower()):
try:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
schedulefree_wrapper_kwargs = {}
if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0:
for arg in args.schedulefree_wrapper_args:
key, value = arg.split("=")
Expand Down Expand Up @@ -4709,6 +4707,7 @@ def __instancecheck__(self, instance):
optimizer = OptimizerProxy(sf_wrapper)
logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}")
"""

optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
Expand All @@ -4717,7 +4716,7 @@ def __instancecheck__(self, instance):


def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool:
return args.optimizer_type.lower().endswith("schedulefree".lower()) or args.optimizer_schedulefree_wrapper
return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper


def get_dummy_scheduler(optimizer: Optimizer) -> Any:
Expand Down

0 comments on commit 5ad328a

Please sign in to comment.