From bd8d9845c6351d882b85137466a96d9086d2885a Mon Sep 17 00:00:00 2001
From: Sarthak Pati
Date: Sun, 6 Aug 2023 10:08:22 -0400
Subject: [PATCH] updated defaults for sgd

---
 GANDLF/optimizers/wrap_torch.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/GANDLF/optimizers/wrap_torch.py b/GANDLF/optimizers/wrap_torch.py
index fe91be096..9a0455ec5 100644
--- a/GANDLF/optimizers/wrap_torch.py
+++ b/GANDLF/optimizers/wrap_torch.py
@@ -28,10 +28,10 @@ def sgd(parameters):
     optimizer = SGD(
         parameters["model_parameters"],
         lr=parameters.get("learning_rate"),
-        momentum=parameters["optimizer"].get("momentum", 0.9),
-        weight_decay=parameters["optimizer"].get("weight_decay", 0),
+        momentum=parameters["optimizer"].get("momentum", 0.99),
+        weight_decay=parameters["optimizer"].get("weight_decay", 3e-05),
         dampening=parameters["optimizer"].get("dampening", 0),
-        nesterov=parameters["optimizer"].get("Nesterov", False),
+        nesterov=parameters["optimizer"].get("nesterov", True),
     )
 
     return optimizer
@@ -55,7 +55,7 @@ def asgd(parameters):
         alpha=parameters["optimizer"].get("alpha", 0.75),
         t0=parameters["optimizer"].get("t0", 1e6),
         lambd=parameters["optimizer"].get("lambd", 1e-4),
-        weight_decay=parameters["optimizer"].get("weight_decay", 0),
+        weight_decay=parameters["optimizer"].get("weight_decay", 3e-05),
     )
 
 
@@ -177,7 +177,7 @@ def adadelta(parameters):
         lr=parameters.get("learning_rate"),
         rho=parameters["optimizer"].get("rho", 0.9),
         eps=parameters["optimizer"].get("eps", 1e-6),
-        weight_decay=parameters["optimizer"].get("weight_decay", 0),
+        weight_decay=parameters["optimizer"].get("weight_decay", 3e-05),
     )
 
 
@@ -199,7 +199,7 @@ def adagrad(parameters):
         lr=parameters.get("learning_rate"),
         lr_decay=parameters["optimizer"].get("lr_decay", 0),
         eps=parameters["optimizer"].get("eps", 1e-6),
-        weight_decay=parameters["optimizer"].get("weight_decay", 0),
+        weight_decay=parameters["optimizer"].get("weight_decay", 3e-05),
     )
 
 
@@ -222,7 +222,7 @@ def rmsprop(parameters):
         eps=parameters["optimizer"].get("eps", 1e-8),
         centered=parameters["optimizer"].get("centered", False),
         momentum=parameters["optimizer"].get("momentum", 0),
-        weight_decay=parameters["optimizer"].get("weight_decay", 0),
+        weight_decay=parameters["optimizer"].get("weight_decay", 3e-05),
     )
 
 
@@ -242,6 +242,6 @@ def radam(parameters):
         lr=parameters.get("learning_rate"),
         betas=parameters["optimizer"].get("betas", (0.9, 0.999)),
         eps=parameters["optimizer"].get("eps", 1e-8),
-        weight_decay=parameters["optimizer"].get("weight_decay", 0),
+        weight_decay=parameters["optimizer"].get("weight_decay", 3e-05),
         foreach=parameters["optimizer"].get("foreach", None),
     )
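
Reviewer note (not part of the patch): besides the new default values, the sgd() hunk also corrects the config key lookup from "Nesterov" to lowercase "nesterov", so configs that still use the old capitalized key will silently fall back to the new default of True. The following minimal Python sketch shows how the new defaults resolve when the config supplies no overrides; it assumes the GANDLF-style parameters dict used throughout wrap_torch.py, and the toy model and learning rate are illustrative placeholders only.

    import torch

    from GANDLF.optimizers.wrap_torch import sgd

    # Toy stand-in for a real GANDLF model (illustrative only).
    model = torch.nn.Linear(4, 2)

    # Leave the "optimizer" sub-dict empty so every .get() call in the
    # wrapper falls through to the defaults introduced by this patch.
    parameters = {
        "model_parameters": model.parameters(),
        "learning_rate": 0.01,  # placeholder value
        "optimizer": {},
    }

    optimizer = sgd(parameters)

    # torch.optim.Optimizer exposes its constructor defaults via .defaults;
    # after this patch we expect momentum=0.99, weight_decay=3e-05, and
    # nesterov=True. Note that nesterov=True requires momentum > 0 and
    # dampening == 0, which these defaults satisfy.
    print(optimizer.defaults["momentum"])      # 0.99
    print(optimizer.defaults["weight_decay"])  # 3e-05
    print(optimizer.defaults["nesterov"])      # True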