Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FEATURE] RAdam optimizer implementation #20762

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions python/mxnet/optimizer/radam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""RAdam optimizer."""
from __future__ import absolute_import
from .optimizer import Optimizer, register
from ..ndarray import (zeros, clip, sqrt, square, full, NDArray)

__all__ = ['RAdam']


@register
class RAdam(Optimizer):
"""The RAdam optimizer.

This class implements the optimizer described in *On the Variance of the Adaptive Learning Rate and Beyond*,
available at https://arxiv.org/pdf/1908.03265.pdf.

Updates are applied by::

grad = clip(grad * rescale_grad, clip_gradient)
m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * (grad**2)
m_hat = m / (1 - beta1)
p = p_inf - (2 * step * beta2) / (1 - beta2)

If p >= 5::

lr_a = sqrt((1 - beta2) / (v + epsilon))
r = sqrt(((p - 4) * (p - 2) * p_inf) / ((p_inf - 4) * (p_inf - 2) * p))
w = w - (lr * m_hat * r * lr_a)

If p < 5::

w = w - (lr * m_hat)

This optimizer accepts the following parameters in addition to those accepted
by :class:`.Optimizer`.

Parameters
----------
learning_rate : float, default 0.001
The initial learning rate. If None, the optimization will use the
learning rate from ``lr_scheduler``. If not None, it will overwrite
the learning rate in ``lr_scheduler``. If None and ``lr_scheduler``
is also None, then it will be set to 0.01 by default.
beta1 : float, default 0.9
Exponential decay rate for the first moment estimates.
beta2 : float, default 0.999
Exponential decay rate for the second moment estimates.
epsilon : float, default 1e-8
Small value to avoid division by 0.
use_fused_step : bool, default False
Whether or not to use fused kernels for optimizer.
When use_fused_step=False, step is called,
otherwise, fused_step is called.
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
use_fused_step=False, **kwargs):
super(RAdam, self).__init__(use_fused_step=use_fused_step,
learning_rate=learning_rate,
**kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon

def create_state(self, index, weight):
"""state creation function."""
return (zeros(weight.shape, weight.context, dtype=weight.dtype), # mean
zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance

def step(self, indices, weights, grads, states):
"""Perform an optimization step using gradients and states.

Parameters
----------
indices : list of int
List of unique indices of the parameters into the individual learning rates
and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
and `set_wd_mult()`, respectively.
weights : list of NDArray
List of parameters to be updated.
grads : list of NDArray
List of gradients of the objective with respect to this parameter.
states : List of any obj
List of state returned by `create_state()`.
"""
for index, weight, grad, state in zip(indices, weights, grads, states):
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)
t = self._index_update_count[index]

bias_correction1 = 1 - self.beta1 ** t
bias_correction2 = 1 - self.beta2 ** t

# preprocess grad
grad *= self.rescale_grad
if self.clip_gradient is not None:
grad = clip(grad, - self.clip_gradient, self.clip_gradient)
grad += wd * weight

# update mean and var
mean, var = state
mean[:] *= self.beta1
mean[:] += (1. - self.beta1) * grad
var[:] *= self.beta2
var[:] += (1. - self.beta2) * square(grad)

bias_corrected_mean = mean / bias_correction1

# maximum length of the approximated SMA
rho_inf = 2 / (1 - self.beta2) - 1
# compute the length of the approximated SMA
rho_t = rho_inf - 2 * t * (self.beta2 ** t) / bias_correction2

#update weight
if rho_t >= 5:
# compute the variance rectification term and update parameters accordingly
rect = sqrt((rho_t - 4) * (rho_t - 2) * rho_inf / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
adaptive_lr = sqrt(bias_correction2) / (sqrt(var) + self.epsilon)
weight[:] += bias_corrected_mean * lr * adaptive_lr * rect * -1.0
else:
weight[:] += bias_corrected_mean * lr * -1.0
20 changes: 20 additions & 0 deletions tests/python/unittest/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,26 @@ def test_adamW():
opt2(use_fused_step=True, **kwarg), shapes, dtype,
rtol=1e-3, atol=2e-3)

def test_radam():
opt1 = mx.optimizer.RAdam
opt2 = mx.optimizer.RAdam
shapes = [(3, 4, 5), (10, 4), (7,)]
beta1_options = [{}, {'beta1': 0.5}, {'beta1': 0.7}]
beta2_options = [{}, {'beta2': 0.8}, {'beta2': 0.9}]
cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
mp_options = [{'multi_precision': False}, {'multi_precision': True}]
agg_options = [{'aggregate_num': 0}, {'aggregate_num': 1},
{'aggregate_num': 4}, {'aggregate_num': np.inf}]
for dtype in [np.float16, np.float32]:
for params in itertools.product(beta1_options, beta2_options, cg_options,
rg_options, wd_options, mp_options, agg_options):
kwarg = {k: v for param in params for k, v in param.items()}
if (dtype == np.float16 and ('multi_precision' not in kwarg or not kwarg['multi_precision'])):
continue
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shapes, dtype)

def test_adabelief():
opt1 = mx.optimizer.AdaBelief
opt2 = mx.optimizer.AdaBelief
Expand Down