
Commit bd8d984

updated defaults for sgd
sarthakpati authored Aug 6, 2023
1 parent c2040ea commit bd8d984
Showing 1 changed file with 8 additions and 8 deletions.
GANDLF/optimizers/wrap_torch.py: 16 changes (8 additions, 8 deletions)
@@ -28,10 +28,10 @@ def sgd(parameters):
     optimizer = SGD(
         parameters["model_parameters"],
         lr=parameters.get("learning_rate"),
-        momentum=parameters["optimizer"].get("momentum", 0.9),
-        weight_decay=parameters["optimizer"].get("weight_decay", 0),
+        momentum=parameters["optimizer"].get("momentum", 0.99),
+        weight_decay=parameters["optimizer"].get("weight_decay", 3e-05),
         dampening=parameters["optimizer"].get("dampening", 0),
-        nesterov=parameters["optimizer"].get("Nesterov", False),
+        nesterov=parameters["optimizer"].get("nesterov", True),
     )

     return optimizer
@@ -55,7 +55,7 @@ def asgd(parameters):
         alpha=parameters["optimizer"].get("alpha", 0.75),
         t0=parameters["optimizer"].get("t0", 1e6),
         lambd=parameters["optimizer"].get("lambd", 1e-4),
-        weight_decay=parameters["optimizer"].get("weight_decay", 0),
+        weight_decay=parameters["optimizer"].get("weight_decay", 3e-05),
     )


@@ -177,7 +177,7 @@ def adadelta(parameters):
         lr=parameters.get("learning_rate"),
         rho=parameters["optimizer"].get("rho", 0.9),
         eps=parameters["optimizer"].get("eps", 1e-6),
-        weight_decay=parameters["optimizer"].get("weight_decay", 0),
+        weight_decay=parameters["optimizer"].get("weight_decay", 3e-05),
     )


@@ -199,7 +199,7 @@ def adagrad(parameters):
         lr=parameters.get("learning_rate"),
         lr_decay=parameters["optimizer"].get("lr_decay", 0),
         eps=parameters["optimizer"].get("eps", 1e-6),
-        weight_decay=parameters["optimizer"].get("weight_decay", 0),
+        weight_decay=parameters["optimizer"].get("weight_decay", 3e-05),
     )


@@ -222,7 +222,7 @@ def rmsprop(parameters):
         eps=parameters["optimizer"].get("eps", 1e-8),
         centered=parameters["optimizer"].get("centered", False),
         momentum=parameters["optimizer"].get("momentum", 0),
-        weight_decay=parameters["optimizer"].get("weight_decay", 0),
+        weight_decay=parameters["optimizer"].get("weight_decay", 3e-05),
     )


@@ -242,6 +242,6 @@ def radam(parameters):
         lr=parameters.get("learning_rate"),
         betas=parameters["optimizer"].get("betas", (0.9, 0.999)),
         eps=parameters["optimizer"].get("eps", 1e-8),
-        weight_decay=parameters["optimizer"].get("weight_decay", 0),
+        weight_decay=parameters["optimizer"].get("weight_decay", 3e-05),
         foreach=parameters["optimizer"].get("foreach", None),
     )
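For context, here is a minimal sketch (not part of the commit) of how the updated sgd wrapper resolves its defaults when a user config leaves the optimizer sub-keys unset. The model and learning-rate values below are illustrative assumptions, not values from the repository. Note that the lookup key also changed from "Nesterov" to "nesterov", so a config still using the old capitalized key will now silently fall back to the new default of True.

    # Sketch only: assumed parameters layout, inferred from the diff above.
    import torch
    from torch.optim import SGD

    model = torch.nn.Linear(4, 2)  # stand-in model for illustration

    parameters = {
        "model_parameters": model.parameters(),
        "learning_rate": 0.01,
        "optimizer": {"type": "sgd"},  # no momentum/weight_decay/nesterov given
    }

    optimizer = SGD(
        parameters["model_parameters"],
        lr=parameters.get("learning_rate"),
        momentum=parameters["optimizer"].get("momentum", 0.99),           # was 0.9
        weight_decay=parameters["optimizer"].get("weight_decay", 3e-05),  # was 0
        dampening=parameters["optimizer"].get("dampening", 0),
        nesterov=parameters["optimizer"].get("nesterov", True),           # key was "Nesterov", default was False
    )

    print(optimizer.defaults)  # momentum=0.99, weight_decay=3e-05, nesterov=True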
