diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_delta_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_delta_cpu_kernel.h index 4ac7df24caa..890977bc7e9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_delta_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_delta_cpu_kernel.h @@ -36,7 +36,7 @@ class AdamDeltaCPUKernel : public CPUKernel { size_t elem_num_{0}; }; -MS_REG_CPU_KERNEL(AdamDelta, +MS_REG_CPU_KERNEL(AdamNoUpdateParam, KernelAttr() .AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32) @@ -47,7 +47,6 @@ MS_REG_CPU_KERNEL(AdamDelta, .AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32), AdamDeltaCPUKernel); } // namespace kernel diff --git a/mindspore/nn/optim/__init__.py b/mindspore/nn/optim/__init__.py index e31e6345d41..70932aac705 100644 --- a/mindspore/nn/optim/__init__.py +++ b/mindspore/nn/optim/__init__.py @@ -20,7 +20,7 @@ The optimizer is used to calculate and update the gradients. """ from .optimizer import Optimizer from .momentum import Momentum -from .adam import Adam, AdamWeightDecay +from .adam import Adam, AdamWeightDecay, AdamOffload from .lamb import Lamb from .sgd import SGD from .lars import LARS @@ -29,5 +29,5 @@ from .rmsprop import RMSProp from .proximal_ada_grad import ProximalAdagrad from .lazyadam import LazyAdam -__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam', +__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam', 'AdamOffload', 'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad'] diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index 27ea0c56b26..c023cf86bd0 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -90,22 +90,22 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool") def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, - beta2_power, beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter): + beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter): """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" success = True indices = gradient.indices values = gradient.values if ps_parameter: op_shape = P.Shape() - shapes = (op_shape(params), op_shape(m), op_shape(v), + shapes = (op_shape(param), op_shape(m), op_shape(v), op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices)) success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, - eps, values, indices), shapes), params)) + eps, values, indices), shapes), param)) return success if not target: - success = F.depend(success, sparse_opt(params, m, v, beta1_power, beta2_power, lr, beta1, beta2, + success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2, eps, values, indices)) else: op_mul = P.Mul() @@ -145,12 +145,12 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power) - next_param = params - lr_t * param_update + next_param = param - lr_t * param_update F.control_depend(assign_m, next_m) F.control_depend(assign_v, next_v) - success = F.depend(success, F.assign(params, next_param)) + success = F.depend(success, F.assign(param, next_param)) success = F.depend(success, F.assign(m, next_m)) success = F.depend(success, F.assign(v, next_v)) @@ -160,18 +160,29 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, - beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter): + beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2, ps_parameter): """Apply adam optimizer to the weight parameter using Tensor.""" success = True if ps_parameter: op_shape = P.Shape() success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), - (op_shape(params), op_shape(moment1), op_shape(moment2))), params)) + (op_shape(param), op_shape(moment1), op_shape(moment2))), param)) else: - success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, + success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)) return success + +@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Tensor") +def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2): + """Apply AdamOffload optimizer to the weight parameter using Tensor.""" + success = True + delat_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient) + success = F.depend(success, F.assign_add(param, delat_param)) + return success + + def _check_param_value(beta1, beta2, eps, prim_name): """Check the type of inputs.""" validator.check_value_type("beta1", beta1, [float], prim_name) @@ -443,3 +454,146 @@ class AdamWeightDecay(Optimizer): if self.use_parallel: self.broadcast_params(optim_result) return optim_result + + +class AdamOffload(Optimizer): + r""" + Updates gradients by the Adaptive Moment Estimation (Adam) algorithm. This optimizer will offload Adam optimizer to + host CPU and keep parameters being updated on the device, to minimize the memory cost. Although that would bring + about an increase of performance overhead, the optimizer could be used to run a larger model. + + The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization `_. + + The updating formulas are as follows, + + .. math:: + \begin{array}{ll} \\ + m = \beta_1 * m + (1 - \beta_1) * g \\ + v = \beta_2 * v + (1 - \beta_2) * g * g \\ + l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\ + w = w - l * \frac{m}{\sqrt{v} + \epsilon} + \end{array} + + :math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`, + :math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent + `beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent + `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`, + :math:`\epsilon` represents `eps`. + + Note: + This optimizer only supports `GRAPH_MODE` currently. + + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. + + To improve parameter groups performance, the customized order of parameters is supported. + + Args: + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr", "weight_decay" and "order_params" are the keys can be parsed. + + - params: Required. The value must be a list of `Parameter`. + + - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + + - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + + - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and + the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters + which in the 'order_params' must be in one of group parameters. + + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate. + When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero + dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 1e-3. + beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0). + Default: 0.9. + beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). + Default: 0.999. + eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default: + 1e-8. + use_locking (bool): Whether to enable a lock to protect variable tensors from being updated. + If true, updates of the var, m, and v tensors will be protected by a lock. + If false, the result is unpredictable. Default: False. + use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients. + If true, update the gradients using NAG. + If false, update the gradients without using NAG. Default: False. + weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0. + loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0. + + Inputs: + - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. + + Outputs: + Tensor[bool], the value is True. + + Examples: + >>> net = Net() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = nn.AdamOffload(params=net.trainable_params()) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + >>> {'params': no_conv_params, 'lr': 0.01}, + >>> {'order_params': net.trainable_params()}] + >>> optim = nn.AdamOffload(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and defaule weight decay of 0.0. + >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. + >>> + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> model = Model(net, loss_fn=loss, optimizer=optim) + """ + + def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, + use_nesterov=False, weight_decay=0.0, loss_scale=1.0): + super(AdamOffload, self).__init__(learning_rate, params, weight_decay, loss_scale) + _check_param_value(beta1, beta2, eps, self.cls_name) + validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) + validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name) + + self.beta1 = Tensor(beta1, mstype.float32) + self.beta2 = Tensor(beta2, mstype.float32) + self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power") + self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power") + self.eps = Tensor(eps, mstype.float32) + self.use_nesterov = use_nesterov + self.use_locking = use_locking + self.moment1 = self.parameters.clone(prefix="moment1", init='zeros') + self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') + + self.hyper_map = C.HyperMap() + self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov) + self.opt.add_prim_attr("primitive_target", "CPU") + + def construct(self, gradients): + params = self.parameters + moment1 = self.moment1 + moment2 = self.moment2 + gradients = self.decay_weight(gradients) + gradients = self.scale_grad(gradients) + lr = self.get_lr() + + beta1_power = self.beta1_power * self.beta1 + self.beta1_power = beta1_power + beta2_power = self.beta2_power * self.beta2 + self.beta2_power = beta2_power + if self.is_group_lr: + success = self.map_(F.partial(_adam_opt, self.opt, + beta1_power, beta2_power, self.beta1, self.beta2, self.eps), + lr, gradients, params, moment1, moment2) + else: + success = self.map_(F.partial(_adam_opt, self.opt, + beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr), + gradients, params, moment1, moment2) + return success diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index c003e992b95..74107fbb50a 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -58,7 +58,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, RandomCategorical, StandardLaplace, Multinomial) -from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, ApplyMomentum, BatchNorm, +from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm, BiasAdd, Conv2D, DepthwiseConv2dNative, DropoutDoMask, Dropout, @@ -133,6 +133,7 @@ __all__ = [ 'Adam', 'FusedSparseAdam', 'FusedSparseLazyAdam', + 'AdamNoUpdateParam', 'Softplus', 'Softmax', 'Softsign', diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 0d40a620281..63b6dc50faf 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -3309,6 +3309,107 @@ class Adam(PrimitiveWithInfer): return var_dtype, m_dtype, v_dtype +class AdamNoUpdateParam(PrimitiveWithInfer): + r""" + Updates gradients by Adaptive Moment Estimation (Adam) algorithm. This operator do not update the parameter, but + calculate the value that should be added to the parameter instead. + + The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization `_. + + The updating formulas are as follows, + + .. math:: + \begin{array}{ll} \\ + m = \beta_1 * m + (1 - \beta_1) * g \\ + v = \beta_2 * v + (1 - \beta_2) * g * g \\ + l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\ + \Delta{w} = - l * \frac{m}{\sqrt{v} + \epsilon} + \end{array} + + :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents + `gradient`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, + :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent `beta1_power` and + `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents the parameter to be updated, + :math:`\epsilon`represents `epsilon`. + + Args: + use_locking (bool): Whether to enable a lock to protect variable tensors from being updated. + If true, updates of the var, m, and v tensors will be protected by a lock. + If false, the result is unpredictable. Default: False. + use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients. + If true, update the gradients using NAG. + If false, update the gradients without using NAG. Default: False. + + Inputs: + - **m** (Tensor) - The 1st moment vector in the updating formula. The data type must be float32. + - **v** (Tensor) - the 2nd moment vector in the updating formula. The shape must be the same as `m`. + The data type must be float32. + - **beta1_power** (Tensor) - :math:`beta_1^t` in the updating formula. The data type must be float32. + - **beta2_power** (Tensor) - :math:`beta_2^t` in the updating formula. The data type must be float32. + - **lr** (Tensor) - :math:`l` in the updating formula. The data type must be float32. + - **beta1** (Tensor) - The exponential decay rate for the 1st moment estimations. The data type must be float32. + - **beta2** (Tensor) - The exponential decay rate for the 2nd moment estimations. The data type must be float32. + - **epsilon** (Tensor) - Term added to the denominator to improve numerical stability. The data type must be + float32. + - **gradient** (Tensor) - Gradient, the shape must be the same as `m`, the data type must be float32. + + Outputs: + Tensor, whose shape and data type are the same with `gradient`, is a value that should be added to the + parameter to be updated. + + Examples: + >>> import numpy as np + >>> import mindspore as ms + >>> import mindspore.nn as nn + >>> from mindspore import Tensor, Parameter + >>> from mindspore.ops import operations as P + >>> + >>> class Net(nn.Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self.adam = P.AdamNoUpdateParam() + >>> self.m = Parameter(Tensor(np.array([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2]]).astype(np.float32)), + >>> name="m") + >>> self.v = Parameter(Tensor(np.array([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2]]).astype(np.float32)), + >>> name="v") + >>> def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad): + >>> out = self.adam(self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) + >>> return out + >>> net = Net() + >>> beta1_power = Tensor(0.9, ms.float32) + >>> beta2_power = Tensor(0.999, ms.float32) + >>> lr = Tensor(0.001, ms.float32) + >>> beta1 = Tensor(0.9, ms.float32) + >>> beta2 = Tensor(0.999, ms.float32) + >>> epsilon = Tensor(1e-8, ms.float32) + >>> gradient = Tensor(np.array([[0.1, 0.1, 0.1], [0.1, 0.1, 0.1]]).astype(np.float32)) + >>> + >>> result = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient) + >>> print(result) + [[-0.00010004 -0.00010004 -0.00010004] + [-0.00013441 -0.00013441 -0.00013441]] + + """ + @prim_attr_register + def __init__(self, use_locking=False, use_nesterov=False): + validator.check_value_type("use_locking", use_locking, [bool], self.name) + validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name) + + def infer_shape(self, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape, + beta1_shape, beta2_shape, epsilon_shape, grad_shape): + validator.check("grad_shape", grad_shape, "m_shape", m_shape, Rel.EQ, self.name) + validator.check("grad_shape", grad_shape, "v_shape", v_shape, Rel.EQ, self.name) + return grad_shape + + def infer_dtype(self, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, + beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype): + args = {"m": m_dtype, "v": v_dtype, "grad": grad_dtype, + "beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype, + "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype} + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name) + return grad_dtype + + class FusedSparseAdam(PrimitiveWithInfer): r""" Merges the duplicate value of the gradient and then updates parameters by Adaptive Moment Estimation (Adam) diff --git a/tests/ut/python/nn/optim/test_adam.py b/tests/ut/python/nn/optim/test_adam.py index e8e6bd9cacd..687a1363aa3 100644 --- a/tests/ut/python/nn/optim/test_adam.py +++ b/tests/ut/python/nn/optim/test_adam.py @@ -183,6 +183,24 @@ def test_adamweightdecay_group(): _executor.compile(train_network, inputs, label) +def test_adamoffload_group(): + """ test_adam_group_lr_and_weight_decay """ + inputs = Tensor(np.ones([1, 64]).astype(np.float32)) + label = Tensor(np.zeros([1, 10]).astype(np.float32)) + net = Net() + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + net_with_loss = WithLossCell(net, loss) + all_params = net.trainable_params() + + schedule_lr = lr_schedules.PolynomialDecayLR(0.01, 0.0001, 3, power=1.0) + group_params = [{'params': [all_params[0]], 'lr': 0.02, 'weight_decay': 0.9}, + {'params': [all_params[1]]}] + optimizer = nn.AdamOffload(group_params, learning_rate=schedule_lr) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) + + def test_AdamWeightDecay_beta1(): net = Net() print("**********", net.get_parameters())