add AdamOffload optimizer

This commit is contained in:
wangnan39@huawei.com 2020-11-17 16:23:45 +08:00
parent e1cfeeb1dd
commit ab811fca8f
6 changed files with 287 additions and 14 deletions

View File

@ -36,7 +36,7 @@ class AdamDeltaCPUKernel : public CPUKernel {
size_t elem_num_{0};
};
MS_REG_CPU_KERNEL(AdamDelta,
MS_REG_CPU_KERNEL(AdamNoUpdateParam,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
@ -47,7 +47,6 @@ MS_REG_CPU_KERNEL(AdamDelta,
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
AdamDeltaCPUKernel);
} // namespace kernel

View File

@ -20,7 +20,7 @@ The optimizer is used to calculate and update the gradients.
"""
from .optimizer import Optimizer
from .momentum import Momentum
from .adam import Adam, AdamWeightDecay
from .adam import Adam, AdamWeightDecay, AdamOffload
from .lamb import Lamb
from .sgd import SGD
from .lars import LARS
@ -29,5 +29,5 @@ from .rmsprop import RMSProp
from .proximal_ada_grad import ProximalAdagrad
from .lazyadam import LazyAdam
__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam',
__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam', 'AdamOffload',
'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad']

View File

@ -90,22 +90,22 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
beta2_power, beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter):
beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter):
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
success = True
indices = gradient.indices
values = gradient.values
if ps_parameter:
op_shape = P.Shape()
shapes = (op_shape(params), op_shape(m), op_shape(v),
shapes = (op_shape(param), op_shape(m), op_shape(v),
op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
eps, values, indices), shapes), params))
eps, values, indices), shapes), param))
return success
if not target:
success = F.depend(success, sparse_opt(params, m, v, beta1_power, beta2_power, lr, beta1, beta2,
success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
eps, values, indices))
else:
op_mul = P.Mul()
@ -145,12 +145,12 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
next_param = params - lr_t * param_update
next_param = param - lr_t * param_update
F.control_depend(assign_m, next_m)
F.control_depend(assign_v, next_v)
success = F.depend(success, F.assign(params, next_param))
success = F.depend(success, F.assign(param, next_param))
success = F.depend(success, F.assign(m, next_m))
success = F.depend(success, F.assign(v, next_v))
@ -160,18 +160,29 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter):
beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2, ps_parameter):
"""Apply adam optimizer to the weight parameter using Tensor."""
success = True
if ps_parameter:
op_shape = P.Shape()
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
(op_shape(params), op_shape(moment1), op_shape(moment2))), params))
(op_shape(param), op_shape(moment1), op_shape(moment2))), param))
else:
success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
eps, gradient))
return success
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor")
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
"""Apply AdamOffload optimizer to the weight parameter using Tensor."""
success = True
delat_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
success = F.depend(success, F.assign_add(param, delat_param))
return success
def _check_param_value(beta1, beta2, eps, prim_name):
"""Check the type of inputs."""
validator.check_value_type("beta1", beta1, [float], prim_name)
@ -443,3 +454,146 @@ class AdamWeightDecay(Optimizer):
if self.use_parallel:
self.broadcast_params(optim_result)
return optim_result
class AdamOffload(Optimizer):
r"""
Updates gradients by the Adaptive Moment Estimation (Adam) algorithm. This optimizer will offload Adam optimizer to
host CPU and keep parameters being updated on the device, to minimize the memory cost. Although that would bring
about an increase of performance overhead, the optimizer could be used to run a larger model.
The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
The updating formulas are as follows,
.. math::
\begin{array}{ll} \\
m = \beta_1 * m + (1 - \beta_1) * g \\
v = \beta_2 * v + (1 - \beta_2) * g * g \\
l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
w = w - l * \frac{m}{\sqrt{v} + \epsilon}
\end{array}
:math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
:math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent
`beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent
`beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
:math:`\epsilon` represents `eps`.
Note:
This optimizer only supports `GRAPH_MODE` currently.
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters is supported.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
"lr", "weight_decay" and "order_params" are the keys can be parsed.
- params: Required. The value must be a list of `Parameter`.
- lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
If not, the `learning_rate` in the API will be used.
- weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
will be used. If not, the `weight_decay` in the API will be used.
- order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
which in the 'order_params' must be in one of group parameters.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
Default: 1e-3.
beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
Default: 0.9.
beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
Default: 0.999.
eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
1e-8.
use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
If true, updates of the var, m, and v tensors will be protected by a lock.
If false, the result is unpredictable. Default: False.
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
If true, update the gradients using NAG.
If false, update the gradients without using NAG. Default: False.
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
Tensor[bool], the value is True.
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.AdamOffload(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
>>> {'params': no_conv_params, 'lr': 0.01},
>>> {'order_params': net.trainable_params()}]
>>> optim = nn.AdamOffload(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and defaule weight decay of 0.0.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
super(AdamOffload, self).__init__(learning_rate, params, weight_decay, loss_scale)
_check_param_value(beta1, beta2, eps, self.cls_name)
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
self.beta1 = Tensor(beta1, mstype.float32)
self.beta2 = Tensor(beta2, mstype.float32)
self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
self.eps = Tensor(eps, mstype.float32)
self.use_nesterov = use_nesterov
self.use_locking = use_locking
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov)
self.opt.add_prim_attr("primitive_target", "CPU")
def construct(self, gradients):
params = self.parameters
moment1 = self.moment1
moment2 = self.moment2
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
beta1_power = self.beta1_power * self.beta1
self.beta1_power = beta1_power
beta2_power = self.beta2_power * self.beta2
self.beta2_power = beta2_power
if self.is_group_lr:
success = self.map_(F.partial(_adam_opt, self.opt,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
lr, gradients, params, moment1, moment2)
else:
success = self.map_(F.partial(_adam_opt, self.opt,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
gradients, params, moment1, moment2)
return success

View File

@ -58,7 +58,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial)
from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, ApplyMomentum, BatchNorm,
from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm,
BiasAdd, Conv2D,
DepthwiseConv2dNative,
DropoutDoMask, Dropout,
@ -133,6 +133,7 @@ __all__ = [
'Adam',
'FusedSparseAdam',
'FusedSparseLazyAdam',
'AdamNoUpdateParam',
'Softplus',
'Softmax',
'Softsign',

View File

@ -3309,6 +3309,107 @@ class Adam(PrimitiveWithInfer):
return var_dtype, m_dtype, v_dtype
class AdamNoUpdateParam(PrimitiveWithInfer):
r"""
Updates gradients by Adaptive Moment Estimation (Adam) algorithm. This operator do not update the parameter, but
calculate the value that should be added to the parameter instead.
The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
The updating formulas are as follows,
.. math::
\begin{array}{ll} \\
m = \beta_1 * m + (1 - \beta_1) * g \\
v = \beta_2 * v + (1 - \beta_2) * g * g \\
l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
\Delta{w} = - l * \frac{m}{\sqrt{v} + \epsilon}
\end{array}
:math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
`gradient`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
:math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent `beta1_power` and
`beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents the parameter to be updated,
:math:`\epsilon`represents `epsilon`.
Args:
use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
If true, updates of the var, m, and v tensors will be protected by a lock.
If false, the result is unpredictable. Default: False.
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
If true, update the gradients using NAG.
If false, update the gradients without using NAG. Default: False.
Inputs:
- **m** (Tensor) - The 1st moment vector in the updating formula. The data type must be float32.
- **v** (Tensor) - the 2nd moment vector in the updating formula. The shape must be the same as `m`.
The data type must be float32.
- **beta1_power** (Tensor) - :math:`beta_1^t` in the updating formula. The data type must be float32.
- **beta2_power** (Tensor) - :math:`beta_2^t` in the updating formula. The data type must be float32.
- **lr** (Tensor) - :math:`l` in the updating formula. The data type must be float32.
- **beta1** (Tensor) - The exponential decay rate for the 1st moment estimations. The data type must be float32.
- **beta2** (Tensor) - The exponential decay rate for the 2nd moment estimations. The data type must be float32.
- **epsilon** (Tensor) - Term added to the denominator to improve numerical stability. The data type must be
float32.
- **gradient** (Tensor) - Gradient, the shape must be the same as `m`, the data type must be float32.
Outputs:
Tensor, whose shape and data type are the same with `gradient`, is a value that should be added to the
parameter to be updated.
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>>
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.adam = P.AdamNoUpdateParam()
>>> self.m = Parameter(Tensor(np.array([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2]]).astype(np.float32)),
>>> name="m")
>>> self.v = Parameter(Tensor(np.array([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2]]).astype(np.float32)),
>>> name="v")
>>> def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
>>> out = self.adam(self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
>>> return out
>>> net = Net()
>>> beta1_power = Tensor(0.9, ms.float32)
>>> beta2_power = Tensor(0.999, ms.float32)
>>> lr = Tensor(0.001, ms.float32)
>>> beta1 = Tensor(0.9, ms.float32)
>>> beta2 = Tensor(0.999, ms.float32)
>>> epsilon = Tensor(1e-8, ms.float32)
>>> gradient = Tensor(np.array([[0.1, 0.1, 0.1], [0.1, 0.1, 0.1]]).astype(np.float32))
>>>
>>> result = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient)
>>> print(result)
[[-0.00010004 -0.00010004 -0.00010004]
[-0.00013441 -0.00013441 -0.00013441]]
"""
@prim_attr_register
def __init__(self, use_locking=False, use_nesterov=False):
validator.check_value_type("use_locking", use_locking, [bool], self.name)
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
def infer_shape(self, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
beta1_shape, beta2_shape, epsilon_shape, grad_shape):
validator.check("grad_shape", grad_shape, "m_shape", m_shape, Rel.EQ, self.name)
validator.check("grad_shape", grad_shape, "v_shape", v_shape, Rel.EQ, self.name)
return grad_shape
def infer_dtype(self, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
args = {"m": m_dtype, "v": v_dtype, "grad": grad_dtype,
"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
return grad_dtype
class FusedSparseAdam(PrimitiveWithInfer):
r"""
Merges the duplicate value of the gradient and then updates parameters by Adaptive Moment Estimation (Adam)

View File

@ -183,6 +183,24 @@ def test_adamweightdecay_group():
_executor.compile(train_network, inputs, label)
def test_adamoffload_group():
""" test_adam_group_lr_and_weight_decay """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
net_with_loss = WithLossCell(net, loss)
all_params = net.trainable_params()
schedule_lr = lr_schedules.PolynomialDecayLR(0.01, 0.0001, 3, power=1.0)
group_params = [{'params': [all_params[0]], 'lr': 0.02, 'weight_decay': 0.9},
{'params': [all_params[1]]}]
optimizer = nn.AdamOffload(group_params, learning_rate=schedule_lr)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_AdamWeightDecay_beta1():
net = Net()
print("**********", net.get_parameters())