From 5ad328a7703336a6f95c47d5c1fd55ba7b5ae38c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 18 Sep 2024 07:44:26 +0900 Subject: [PATCH] comment out schedulefree wrapper --- library/train_util.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d92006a15..a54f23ff6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3303,19 +3303,19 @@ def int_or_float(value): help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', ) - parser.add_argument( - "--optimizer_schedulefree_wrapper", - action="store_true", - help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用", - ) + # parser.add_argument( + # "--optimizer_schedulefree_wrapper", + # action="store_true", + # help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用", + # ) - parser.add_argument( - "--schedulefree_wrapper_args", - type=str, - default=None, - nargs="*", - help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")', - ) + # parser.add_argument( + # "--schedulefree_wrapper_args", + # type=str, + # default=None, + # nargs="*", + # help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")', + # ) parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ") parser.add_argument( @@ -4375,8 +4375,6 @@ def get_optimizer(args, trainable_params): optimizer_kwargs[key] = value # logger.info(f"optkwargs {optimizer}_{kwargs}") - schedulefree_wrapper_kwargs = {} - lr = args.learning_rate optimizer = None @@ -4622,16 +4620,15 @@ def get_optimizer(args, trainable_params): if "." not in case_sensitive_optimizer_type: # from torch.optim optimizer_module = torch.optim - optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type) - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) else: # from other library values = case_sensitive_optimizer_type.split(".") optimizer_module = importlib.import_module(".".join(values[:-1])) case_sensitive_optimizer_type = values[-1] - optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type) - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + """ # wrap any of above optimizer with schedulefree, if optimizer is not schedulefree if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree".lower()): try: @@ -4639,6 +4636,7 @@ def get_optimizer(args, trainable_params): except ImportError: raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + schedulefree_wrapper_kwargs = {} if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0: for arg in args.schedulefree_wrapper_args: key, value = arg.split("=") @@ -4709,6 +4707,7 @@ def __instancecheck__(self, instance): optimizer = OptimizerProxy(sf_wrapper) logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}") + """ optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) @@ -4717,7 +4716,7 @@ def __instancecheck__(self, instance): def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool: - return args.optimizer_type.lower().endswith("schedulefree".lower()) or args.optimizer_schedulefree_wrapper + return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper def get_dummy_scheduler(optimizer: Optimizer) -> Any: