diff --git a/GANDLF/optimizers/__init__.py b/GANDLF/optimizers/__init__.py
index 97de43fa1..b59afb22f 100644
--- a/GANDLF/optimizers/__init__.py
+++ b/GANDLF/optimizers/__init__.py
@@ -15,6 +15,8 @@
 from .wrap_monai import novograd_wrapper
 
+from .ademamix import ademamix_wrapper
+
 global_optimizer_dict = {
     "sgd": sgd,
     "asgd": asgd,
@@ -29,6 +31,7 @@
     "radam": radam,
     "novograd": novograd_wrapper,
     "nadam": nadam,
+    "ademamix": ademamix_wrapper,
 }
diff --git a/GANDLF/optimizers/ademamix.py b/GANDLF/optimizers/ademamix.py
new file mode 100644
index 000000000..63f68d9f9
--- /dev/null
+++ b/GANDLF/optimizers/ademamix.py
@@ -0,0 +1,204 @@
+import math
+from typing import Callable, Iterable, List, Optional, Tuple
+
+import torch
+from torch import Tensor
+from torch.optim import Optimizer
+
+
+class AdEMAMix(Optimizer):
+    r"""Adapted from https://github.com/frgfm/Holocron/blob/main/holocron/optim/ademamix.py
+
+    Implements the AdEMAMix optimizer from `"The AdEMAMix Optimizer: Better, Faster, Older" <https://arxiv.org/abs/2409.03109>`_.
+
+    The moment estimates are updated as follows, :math:`\forall t \geq 1`:
+
+    .. math::
+        m_{1,t} \leftarrow \beta_1 m_{1, t-1} + (1 - \beta_1) g_t \\
+        m_{2,t} \leftarrow \beta_3 m_{2, t-1} + (1 - \beta_3) g_t \\
+        s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) g_t^2
+
+    where :math:`g_t` is the gradient of :math:`\theta_t`,
+    :math:`\beta_1, \beta_2, \beta_3 \in [0, 1]^3` are the exponential average smoothing coefficients,
+    :math:`m_{1,0} = 0,\ m_{2,0} = 0,\ s_0 = 0`, :math:`\epsilon > 0`.
+
+    The biases of :math:`m_{1,t}` and :math:`s_t` are then corrected using:
+
+    .. math::
+        \hat{m_{1,t}} \leftarrow \frac{m_{1,t}}{1 - \beta_1^t} \\
+        \hat{s_t} \leftarrow \frac{s_t}{1 - \beta_2^t}
+
+    And finally the update step is performed using the following rule:
+
+    .. math::
+        \theta_t \leftarrow \theta_{t-1} - \eta \frac{\hat{m_{1,t}} + \alpha m_{2,t}}{\sqrt{\hat{s_t}} + \epsilon}
+
+    where :math:`\theta_t` is the parameter value at step :math:`t` (:math:`\theta_0` being the initialization value),
+    :math:`\eta` is the learning rate, :math:`\alpha > 0` and :math:`\epsilon > 0`.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float, float], optional): coefficients used for running averages (default: (0.9, 0.999, 0.9999))
+        alpha (float, optional): coefficient :math:`\alpha` that weights the slow EMA :math:`m_{2,t}` in the update (default: 5.0)
+        eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
+    """
+
+    def __init__(
+        self,
+        params: Iterable[torch.nn.Parameter],
+        lr: float = 1e-3,
+        betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
+        alpha: float = 5.0,
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+    ) -> None:
+        assert lr >= 0.0, f"Invalid learning rate: {lr}"
+        assert eps >= 0.0, f"Invalid epsilon value: {eps}"
+        assert all(
+            0.0 <= beta < 1.0 for beta in betas
+        ), f"Invalid beta parameters: {betas}"
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "alpha": alpha,
+            "eps": eps,
+            "weight_decay": weight_decay,
+        }
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:  # type: ignore[override]
+        """Performs a single optimization step.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avgs_slow = []
+            exp_avg_sqs = []
+            state_steps = []
+
+            for p in group["params"]:
+                if p.grad is not None:
+                    params_with_grad.append(p)
+                    if p.grad.is_sparse:
+                        raise RuntimeError(
+                            f"{self.__class__.__name__} does not support sparse gradients"
+                        )
+                    grads.append(p.grad)
+
+                    state = self.state[p]
+                    # Lazy state initialization
+                    if len(state) == 0:
+                        state["step"] = 0
+                        # Fast exponential moving average of gradient values
+                        state["exp_avg"] = torch.zeros_like(
+                            p, memory_format=torch.preserve_format
+                        )
+                        # Slow exponential moving average of gradient values
+                        state["exp_avg_slow"] = torch.zeros_like(
+                            p, memory_format=torch.preserve_format
+                        )
+                        # Exponential moving average of squared gradient values
+                        state["exp_avg_sq"] = torch.zeros_like(
+                            p, memory_format=torch.preserve_format
+                        )
+
+                    exp_avgs.append(state["exp_avg"])
+                    exp_avgs_slow.append(state["exp_avg_slow"])
+                    exp_avg_sqs.append(state["exp_avg_sq"])
+
+                    # update the step count for this parameter
+                    state["step"] += 1
+                    # record the updated step value
+                    state_steps.append(state["step"])
+
+            beta1, beta2, beta3 = group["betas"]
+            _update_ademamix(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avgs_slow,
+                exp_avg_sqs,
+                state_steps,
+                beta1,
+                beta2,
+                beta3,
+                group["alpha"],
+                group["lr"],
+                group["weight_decay"],
+                group["eps"],
+            )
+        return loss
+
+
+def _update_ademamix(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avgs_slow: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    state_steps: List[int],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    lr: float,
+    weight_decay: float,
+    eps: float,
+) -> None:
+    r"""Functional API that performs the AdEMAMix parameter update.
+    See :class:`AdEMAMix` for details.
+    """
+    for i, param in enumerate(params):
+        grad = grads[i]
+        m1 = exp_avgs[i]
+        m2 = exp_avgs_slow[i]
+        nu = exp_avg_sqs[i]
+        step = state_steps[i]
+
+        bias_correction1 = 1 - beta1**step
+        bias_correction2 = 1 - beta2**step
+
+        if weight_decay != 0:
+            grad = grad.add(param, alpha=weight_decay)
+
+        # Update the fast EMA (m1), the second-moment EMA (nu), and the slow EMA (m2)
+        m1.mul_(beta1).add_(grad, alpha=1 - beta1)
+        nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+        m2.mul_(beta3).add_(grad, alpha=1 - beta3)
+
+        denom = (nu.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+
+        param.addcdiv_(m1 / bias_correction1 + alpha * m2, denom, value=-lr)
+
+
+def ademamix_wrapper(parameters: dict) -> torch.optim.Optimizer:
+    """
+    Creates an AdEMAMix optimizer (implemented in this module) using the input parameters.
+
+    Args:
+        parameters (dict): A dictionary containing the input parameters for the optimizer.
+
+    Returns:
+        torch.optim.Optimizer: An AdEMAMix optimizer.
+ """ + + return AdEMAMix( + params=parameters["model_parameters"], + lr=parameters.get("learning_rate", 1e-3), + betas=parameters.get("betas", (0.9, 0.999, 0.9999)), + alpha=parameters.get("alpha", 5.0), + eps=parameters.get("eps", 1e-8), + weight_decay=parameters.get("weight_decay", 0.0), + ) diff --git a/GANDLF/optimizers/wrap_monai.py b/GANDLF/optimizers/wrap_monai.py index 23745e4a5..221ba57bd 100644 --- a/GANDLF/optimizers/wrap_monai.py +++ b/GANDLF/optimizers/wrap_monai.py @@ -1,10 +1,11 @@ +import monai from monai.optimizers import Novograd -def novograd_wrapper(parameters): +def novograd_wrapper(parameters: dict) -> monai.optimizers.Novograd: return Novograd( parameters["model_parameters"], - lr=parameters.get("learning_rate"), + lr=parameters.get("learning_rate", 1e-3), betas=parameters["optimizer"].get("betas", (0.9, 0.999)), eps=parameters["optimizer"].get("eps", 1e-8), weight_decay=parameters["optimizer"].get("weight_decay", 3e-05), diff --git a/GANDLF/optimizers/wrap_torch.py b/GANDLF/optimizers/wrap_torch.py index 9852f7973..2f4650bdb 100644 --- a/GANDLF/optimizers/wrap_torch.py +++ b/GANDLF/optimizers/wrap_torch.py @@ -1,3 +1,4 @@ +import torch from torch.optim import ( SGD, ASGD, @@ -14,7 +15,7 @@ ) -def sgd(parameters): +def sgd(parameters: dict) -> torch.optim.SGD: """ Creates a Stochastic Gradient Descent optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -26,7 +27,7 @@ def sgd(parameters): """ # Create the optimizer using the input parameters - optimizer = SGD( + return SGD( parameters["model_parameters"], lr=parameters.get("learning_rate"), momentum=parameters["optimizer"].get("momentum", 0.99), @@ -35,10 +36,8 @@ def sgd(parameters): nesterov=parameters["optimizer"].get("nesterov", True), ) - return optimizer - -def asgd(parameters): +def asgd(parameters: dict) -> torch.optim.ASGD: """ Creates an Averaged Stochastic Gradient Descent optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -60,7 +59,7 @@ def asgd(parameters): ) -def adam(parameters, opt_type="normal"): +def adam(parameters: dict, opt_type: str = "normal") -> torch.optim.Adam: """ Creates an Adam or AdamW optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -91,7 +90,7 @@ def adam(parameters, opt_type="normal"): ) -def adamw(parameters): +def adamw(parameters: dict) -> torch.optim.AdamW: """ Creates an AdamW optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -105,7 +104,7 @@ def adamw(parameters): return adam(parameters, opt_type="AdamW") -def adamax(parameters): +def adamax(parameters: dict) -> torch.optim.Adamax: """ Creates an Adamax optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -141,7 +140,7 @@ def adamax(parameters): # ) -def rprop(parameters): +def rprop(parameters: dict) -> torch.optim.Rprop: """ Creates a Resilient Backpropagation optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -161,7 +160,7 @@ def rprop(parameters): ) -def adadelta(parameters): +def adadelta(parameters: dict) -> torch.optim.Adadelta: """ Creates an Adadelta optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -182,7 +181,7 @@ def adadelta(parameters): ) -def adagrad(parameters): +def adagrad(parameters: dict) -> torch.optim.Adagrad: """ Creates an Adagrad optimizer from the PyTorch `torch.optim` module using the input parameters. 
@@ -204,7 +203,7 @@ def adagrad(parameters): ) -def rmsprop(parameters): +def rmsprop(parameters: dict) -> torch.optim.RMSprop: """ Creates an RMSprop optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -227,7 +226,7 @@ def rmsprop(parameters): ) -def radam(parameters): +def radam(parameters: dict) -> torch.optim.RAdam: """ Creates a RAdam optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -248,7 +247,7 @@ def radam(parameters): ) -def nadam(parameters): +def nadam(parameters: dict) -> torch.optim.NAdam: """ Creates a NAdam optimizer from the PyTorch `torch.optim` module using the input parameters.
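To sanity-check the new dispatch entry, here is a minimal usage sketch, assuming the parameters-dict layout the wrappers in this diff rely on (a `model_parameters` iterable, a top-level `learning_rate`, and an `optimizer` sub-dict of hyperparameters); the toy model and dummy loss are illustrative only and not part of GANDLF.

import torch

from GANDLF.optimizers import global_optimizer_dict

# Toy model standing in for a GANDLF architecture (illustrative only).
model = torch.nn.Linear(4, 2)

# Parameters dict shaped like the ones the wrappers above expect.
params = {
    "model_parameters": model.parameters(),
    "learning_rate": 1e-3,
    "optimizer": {
        "betas": (0.9, 0.999, 0.9999),
        "alpha": 5.0,
        "eps": 1e-8,
        "weight_decay": 0.0,
    },
}

# Look up the new entry through the dispatch dict added in __init__.py.
optimizer = global_optimizer_dict["ademamix"](params)

# One dummy forward/backward/step cycle to confirm the optimizer runs end to end.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()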