forked from mindspore-Ecosystem/mindspore
add RMSProp optimizer
This commit is contained in:
parent
9a717aa1f7
commit
dcd1f0a504
|
@ -184,6 +184,8 @@ const char kNameDiagPart[] = "DiagPart";
|
||||||
const char kNameSpaceToBatch[] = "SpaceToBatch";
|
const char kNameSpaceToBatch[] = "SpaceToBatch";
|
||||||
const char kNameBatchToSpace[] = "BatchToSpace";
|
const char kNameBatchToSpace[] = "BatchToSpace";
|
||||||
const char kNameAtan2[] = "Atan2";
|
const char kNameAtan2[] = "Atan2";
|
||||||
|
const char kNameApplyRMSProp[] = "ApplyRMSProp";
|
||||||
|
const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp";
|
||||||
|
|
||||||
// -----------------OpAdapter initialization--------------
|
// -----------------OpAdapter initialization--------------
|
||||||
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
|
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
|
||||||
|
@ -369,7 +371,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
||||||
{string(kNameDiagPart), ADPT_DESC(DiagPart)},
|
{string(kNameDiagPart), ADPT_DESC(DiagPart)},
|
||||||
{string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},
|
{string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},
|
||||||
{string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)},
|
{string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)},
|
||||||
{string(kNameAtan2), ADPT_DESC(Atan2)}};
|
{string(kNameAtan2), ADPT_DESC(Atan2)},
|
||||||
|
{string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)},
|
||||||
|
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}};
|
||||||
#ifdef ENABLE_GE
|
#ifdef ENABLE_GE
|
||||||
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
|
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -1189,6 +1189,22 @@ INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
|
||||||
ATTR_MAP(Atan2) = EMPTY_ATTR_MAP;
|
ATTR_MAP(Atan2) = EMPTY_ATTR_MAP;
|
||||||
OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}};
|
OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}};
|
||||||
|
|
||||||
|
// ApplyRMSPropD
|
||||||
|
INPUT_MAP(ApplyRMSPropD) = {
|
||||||
|
{1, INPUT_DESC(var)}, {2, INPUT_DESC(ms)}, {3, INPUT_DESC(mom)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}};
|
||||||
|
INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits<float>())},
|
||||||
|
{7, ATTR_DESC(momentum, AnyTraits<float>())},
|
||||||
|
{8, ATTR_DESC(epsilon, AnyTraits<float>())}};
|
||||||
|
ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||||
|
OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}};
|
||||||
|
|
||||||
|
// ApplyCenteredRMSProp
|
||||||
|
INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)},
|
||||||
|
{4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)},
|
||||||
|
{7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}};
|
||||||
|
ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||||
|
OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}};
|
||||||
|
|
||||||
#ifdef ENABLE_GE
|
#ifdef ENABLE_GE
|
||||||
// Print
|
// Print
|
||||||
INPUT_MAP(Print) = EMPTY_INPUT_MAP;
|
INPUT_MAP(Print) = EMPTY_INPUT_MAP;
|
||||||
|
|
|
@ -447,6 +447,12 @@ DECLARE_OP_ADAPTER(BatchToSpaceD)
|
||||||
DECLARE_OP_USE_OUTPUT(BatchToSpaceD)
|
DECLARE_OP_USE_OUTPUT(BatchToSpaceD)
|
||||||
DECLARE_OP_ADAPTER(Atan2)
|
DECLARE_OP_ADAPTER(Atan2)
|
||||||
DECLARE_OP_USE_OUTPUT(Atan2)
|
DECLARE_OP_USE_OUTPUT(Atan2)
|
||||||
|
DECLARE_OP_ADAPTER(ApplyRMSPropD)
|
||||||
|
DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD)
|
||||||
|
DECLARE_OP_USE_OUTPUT(ApplyRMSPropD)
|
||||||
|
DECLARE_OP_ADAPTER(ApplyCenteredRMSProp)
|
||||||
|
DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp)
|
||||||
|
|
||||||
#ifdef ENABLE_GE
|
#ifdef ENABLE_GE
|
||||||
DECLARE_OP_ADAPTER(Print)
|
DECLARE_OP_ADAPTER(Print)
|
||||||
DECLARE_OP_USE_DYN_INPUT(Print)
|
DECLARE_OP_USE_DYN_INPUT(Print)
|
||||||
|
|
|
@ -25,6 +25,7 @@ from .lamb import Lamb
|
||||||
from .sgd import SGD
|
from .sgd import SGD
|
||||||
from .lars import LARS
|
from .lars import LARS
|
||||||
from .ftrl import FTRL
|
from .ftrl import FTRL
|
||||||
|
from .rmsprop import RMSProp
|
||||||
|
|
||||||
__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay',
|
__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay',
|
||||||
'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL']
|
'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'RMSProp']
|
||||||
|
|
|
@ -0,0 +1,187 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""rmsprop"""
|
||||||
|
from mindspore.ops import functional as F, composite as C, operations as P
|
||||||
|
from mindspore.common.initializer import initializer
|
||||||
|
from mindspore.common.parameter import Parameter
|
||||||
|
from mindspore._checkparam import ParamValidator as validator
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from .optimizer import Optimizer, grad_scale
|
||||||
|
|
||||||
|
rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
|
||||||
|
centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
|
||||||
|
|
||||||
|
|
||||||
|
@rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
|
||||||
|
def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad):
|
||||||
|
"""Apply rmsprop optimizer to the weight parameter."""
|
||||||
|
success = True
|
||||||
|
success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
|
||||||
|
return success
|
||||||
|
|
||||||
|
|
||||||
|
@rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
|
||||||
|
def _rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad):
|
||||||
|
"""Apply rmsprop optimizer to the weight parameter using dynamic learning rate."""
|
||||||
|
success = True
|
||||||
|
success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
|
||||||
|
return success
|
||||||
|
|
||||||
|
|
||||||
|
@centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
|
||||||
|
"Tensor", "Tensor")
|
||||||
|
def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad):
|
||||||
|
"""Apply centered rmsprop optimizer to the weight parameter."""
|
||||||
|
success = True
|
||||||
|
success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))
|
||||||
|
return success
|
||||||
|
|
||||||
|
|
||||||
|
@centered_rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
|
||||||
|
"Tensor", "Tensor")
|
||||||
|
def _centered_rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad):
|
||||||
|
"""Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate."""
|
||||||
|
success = True
|
||||||
|
success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))
|
||||||
|
return success
|
||||||
|
|
||||||
|
|
||||||
|
class RMSProp(Optimizer):
|
||||||
|
"""
|
||||||
|
Implements Root Mean Squared Propagation (RMSProp) algorithm.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Update `params` according to the RMSProp algorithm.
|
||||||
|
|
||||||
|
The equation is as follows:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} + \\epsilon}} \\nabla Q_{i}(w)
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
w = w - m_{t}
|
||||||
|
|
||||||
|
The first equation calculates moving average of the squared gradient for
|
||||||
|
each weight. Then dividing the gradient by :math:`\\sqrt{ms_{t} + \\epsilon}`.
|
||||||
|
|
||||||
|
if centered is True:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
g_{t} = \\rho g_{t-1} + (1 - \\rho)\\nabla Q_{i}(w)
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} - g_{t}^2 + \\epsilon}} \\nabla Q_{i}(w)
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
w = w - m_{t}
|
||||||
|
|
||||||
|
where, :math:`w` represents `params`, which will be updated.
|
||||||
|
:math:`g_{t}` is mean gradients, :math:`g_{t-1}` is the last moment of :math:`g_{t}`.
|
||||||
|
:math:`s_{t}` is the mean square gradients, :math:`s_{t-1}` is the last moment of :math:`s_{t}`,
|
||||||
|
:math:`m_{t}` is moment, the delta of `w`, :math:`m_{t-1}` is the last moment of :math:`m_{t}`.
|
||||||
|
:math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, represents `momentum`.
|
||||||
|
:math:`\\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`.
|
||||||
|
:math:`\\eta` is learning rate, represents `learning_rate`. :math:`\\nabla Q_{i}(w)` is gradientse,
|
||||||
|
represents `gradients`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params (list[Parameter]): A list of parameter, which will be updated. The element in `parameters`
|
||||||
|
should be class mindspore.Parameter.
|
||||||
|
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
|
||||||
|
Iterable or a Tensor and the dims of the Tensor is 1,
|
||||||
|
use dynamic learning rate, then the i-th step will
|
||||||
|
take the i-th value as the learning rate.
|
||||||
|
When the learning_rate is float or learning_rate is a Tensor
|
||||||
|
but the dims of the Tensor is 0, use fixed learning rate.
|
||||||
|
Other cases are not supported.
|
||||||
|
decay (float): Decay rate.
|
||||||
|
momentum (float): Hyperparameter of type float, means momentum for the moving average.
|
||||||
|
epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
|
||||||
|
use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False.
|
||||||
|
centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False
|
||||||
|
loss_scale (float): A floating point value for the loss scale. Default: 1.0.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor[bool], the value is True.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> net = Net()
|
||||||
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||||
|
>>> opt = RMSProp(params=net.trainable_params(), learning_rate=lr)
|
||||||
|
>>> model = Model(net, loss, opt)
|
||||||
|
"""
|
||||||
|
def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
|
||||||
|
use_locking=False, centered=False, loss_scale=1.0):
|
||||||
|
super(RMSProp, self).__init__(learning_rate, params)
|
||||||
|
|
||||||
|
if isinstance(momentum, float) and momentum < 0.0:
|
||||||
|
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||||
|
|
||||||
|
if decay < 0.0:
|
||||||
|
raise ValueError("decay should be at least 0.0, but got dampening {}".format(decay))
|
||||||
|
self.decay = decay
|
||||||
|
self.epsilon = epsilon
|
||||||
|
|
||||||
|
validator.check_type("use_locking", use_locking, [bool])
|
||||||
|
validator.check_type("centered", centered, [bool])
|
||||||
|
self.centered = centered
|
||||||
|
if centered:
|
||||||
|
self.opt = P.ApplyCenteredRMSProp(use_locking)
|
||||||
|
self.mg = self.parameters.clone(prefix="mean_grad", init='zeros')
|
||||||
|
else:
|
||||||
|
self.opt = P.ApplyRMSProp(use_locking)
|
||||||
|
|
||||||
|
self.dynamic_lr = False
|
||||||
|
if not isinstance(learning_rate, float):
|
||||||
|
self.dynamic_lr = True
|
||||||
|
self.gather = P.GatherV2()
|
||||||
|
self.assignadd = P.AssignAdd()
|
||||||
|
self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
|
||||||
|
self.axis = 0
|
||||||
|
|
||||||
|
self.momentum = momentum
|
||||||
|
|
||||||
|
self.ms = self.parameters.clone(prefix="mean_square", init='zeros')
|
||||||
|
self.moment = self.parameters.clone(prefix="moment", init='zeros')
|
||||||
|
self.hyper_map = C.HyperMap()
|
||||||
|
|
||||||
|
self.decay = decay
|
||||||
|
self.reciprocal_scale = 1.0 / loss_scale
|
||||||
|
|
||||||
|
def construct(self, gradients):
|
||||||
|
params = self.parameters
|
||||||
|
if self.reciprocal_scale != 1.0:
|
||||||
|
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
|
||||||
|
if self.dynamic_lr:
|
||||||
|
lr = self.gather(self.learning_rate, self.global_step, self.axis)
|
||||||
|
F.control_depend(lr, self.assignadd(self.global_step, self.one))
|
||||||
|
else:
|
||||||
|
lr = self.learning_rate
|
||||||
|
if self.centered:
|
||||||
|
success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
|
||||||
|
self.momentum), params, self.mg, self.ms, self.moment, gradients)
|
||||||
|
else:
|
||||||
|
success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
|
||||||
|
self.momentum), params, self.ms, self.moment, gradients)
|
||||||
|
return success
|
|
@ -394,8 +394,8 @@ def _split_shape_index(input_shape, axis):
|
||||||
axis = tuple([axis])
|
axis = tuple([axis])
|
||||||
reduction_indices = tuple([(i + rank) % rank for i in axis])
|
reduction_indices = tuple([(i + rank) % rank for i in axis])
|
||||||
other_indices = tuple(set(range(rank)) - set(reduction_indices))
|
other_indices = tuple(set(range(rank)) - set(reduction_indices))
|
||||||
reduced_num = reduce(lambda x, y: x * y, [input_shape[i] for i in reduction_indices])
|
reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices])
|
||||||
other_num = reduce(lambda x, y: x * y, [input_shape[i] for i in other_indices])
|
other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices])
|
||||||
perm = reduction_indices + other_indices
|
perm = reduction_indices + other_indices
|
||||||
return tuple([reduced_num, other_num]), perm
|
return tuple([reduced_num, other_num]), perm
|
||||||
|
|
||||||
|
|
|
@ -65,7 +65,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
|
||||||
SmoothL1Loss, Softmax,
|
SmoothL1Loss, Softmax,
|
||||||
SoftmaxCrossEntropyWithLogits, ROIAlign,
|
SoftmaxCrossEntropyWithLogits, ROIAlign,
|
||||||
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
||||||
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrlD)
|
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrlD,
|
||||||
|
ApplyRMSProp, ApplyCenteredRMSProp)
|
||||||
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey
|
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey
|
||||||
|
|
||||||
|
|
||||||
|
@ -229,6 +230,8 @@ __all__ = [
|
||||||
"SpaceToBatch",
|
"SpaceToBatch",
|
||||||
"BatchToSpace",
|
"BatchToSpace",
|
||||||
"Atan2",
|
"Atan2",
|
||||||
|
"ApplyRMSProp",
|
||||||
|
"ApplyCenteredRMSProp"
|
||||||
]
|
]
|
||||||
|
|
||||||
__all__.sort()
|
__all__.sort()
|
||||||
|
|
|
@ -1359,6 +1359,158 @@ class SGD(PrimitiveWithInfer):
|
||||||
validator.check_typename("stat_dtype", stat_dtype, [mstype.float16, mstype.float32])
|
validator.check_typename("stat_dtype", stat_dtype, [mstype.float16, mstype.float32])
|
||||||
return parameters_dtype
|
return parameters_dtype
|
||||||
|
|
||||||
|
class ApplyRMSProp(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Optimizer that implements the Root Mean Square prop(RMSProp) algorithm.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Update `var` according to the RMSProp algorithm.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} + \\epsilon}} \\nabla Q_{i}(w)
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
w = w - m_{t}
|
||||||
|
|
||||||
|
where, :math:`w` represents `var`, which will be updated.
|
||||||
|
:math:`s_{t}` represents `mean_square`, :math:`s_{t-1}` is the last momentent of :math:`s_{t}`,
|
||||||
|
:math:`m_{t}` represents `moment`, :math:`m_{t-1}` is the last momentent of :math:`m_{t}`.
|
||||||
|
:math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, represents `momentum`.
|
||||||
|
:math:`\\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`.
|
||||||
|
:math:`\\eta` represents `learning_rate`. :math:`\\nabla Q_{i}(w)` represents `grad`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_locking (bool): Enable a lock to protect the update of variable tensors. Default: False.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **var** (Tensor) - Weights to be update.
|
||||||
|
- **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`.
|
||||||
|
- **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
|
||||||
|
- **grad** (Tensor) - Gradients, must have the same type as `var`.
|
||||||
|
- **learning_rate** (Union[Number, Tensor]) - Learning rate.
|
||||||
|
- **decay** (float) - Decay rate.
|
||||||
|
- **momentum** (float) - Momentum.
|
||||||
|
- **epsilon** (float) - Ridge term.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, parameters to be update.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> net = Net()
|
||||||
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||||
|
>>> opt = RMSProp(params=net.trainable_params(), learning_rate=learning_rate)
|
||||||
|
>>> model = Model(net, loss, opt)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, use_locking=False):
|
||||||
|
self.use_locking = validator.check_type("use_locking", use_locking, [bool])
|
||||||
|
|
||||||
|
def infer_shape(self, var_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape,
|
||||||
|
momentum_shape, epsilon_shape):
|
||||||
|
validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape)
|
||||||
|
validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape)
|
||||||
|
validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape)
|
||||||
|
return var_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, decay_dtype,
|
||||||
|
momentum_dtype, epsilon_dtype):
|
||||||
|
validator.check_subclass("var_dtype", var_dtype, mstype.tensor)
|
||||||
|
validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor)
|
||||||
|
validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor)
|
||||||
|
validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor)
|
||||||
|
args = {"var_dtype": var_dtype, "mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype,
|
||||||
|
"grad_dtype": grad_dtype}
|
||||||
|
validator.check_type_same(args, mstype.number_type)
|
||||||
|
|
||||||
|
args = {"learning_rate_dtype": learning_rate_dtype, "decay_dtype": decay_dtype,
|
||||||
|
'momentum_dtype': momentum_dtype, "epsilon_dtype": epsilon_dtype}
|
||||||
|
validator.check_type_same(args, [mstype.float16, mstype.float32])
|
||||||
|
return var_dtype
|
||||||
|
|
||||||
|
|
||||||
|
class ApplyCenteredRMSProp(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Optimizer that implements the centered RMSProp algorithm.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Update `var` according to the centered RMSProp algorithm.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
g_{t} = \\rho g_{t-1} + (1 - \\rho)\\nabla Q_{i}(w)
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} - g_{t}^2 + \\epsilon}} \\nabla Q_{i}(w)
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
w = w - m_{t}
|
||||||
|
|
||||||
|
where, :math:`w` represents `var`, which will be updated.
|
||||||
|
:math:`g_{t}` represents `mean_gradient`, :math:`g_{t-1}` is the last momentent of :math:`g_{t}`.
|
||||||
|
:math:`s_{t}` represents `mean_square`, :math:`s_{t-1}` is the last momentent of :math:`s_{t}`,
|
||||||
|
:math:`m_{t}` represents `moment`, :math:`m_{t-1}` is the last momentent of :math:`m_{t}`.
|
||||||
|
:math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, represents `momentum`.
|
||||||
|
:math:`\\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`.
|
||||||
|
:math:`\\eta` represents `learning_rate`. :math:`\\nabla Q_{i}(w)` represents `grad`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_locking (bool): Enable a lock to protect the update of variable tensors. Default: False.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **var** (Tensor) - Weights to be update.
|
||||||
|
- **mean_gradient** (Tensor) - Mean gradients, must have the same type as `var`.
|
||||||
|
- **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`.
|
||||||
|
- **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
|
||||||
|
- **grad** (Tensor) - Gradients, must have the same type as `var`.
|
||||||
|
- **learning_rate** (Union[Number, Tensor]) - Learning rate.
|
||||||
|
- **decay** (float) - Decay rate.
|
||||||
|
- **momentum** (float) - Momentum.
|
||||||
|
- **epsilon** (float) - Ridge term.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, parameters to be update.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> net = Net()
|
||||||
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||||
|
>>> opt = RMSProp(params=net.trainable_params(), learning_rate=learning_rate, centered=True)
|
||||||
|
>>> model = Model(net, loss, opt)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, use_locking=False):
|
||||||
|
self.use_locking = validator.check_type("use_locking", use_locking, [bool])
|
||||||
|
|
||||||
|
def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape,
|
||||||
|
learning_rate_shape, decay_shape, momentum_shape, epsilon_shape):
|
||||||
|
validator.check_param_equal("var_shape", var_shape, "mean_gradient_shape", mean_gradient_shape)
|
||||||
|
validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape)
|
||||||
|
validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape)
|
||||||
|
validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape)
|
||||||
|
return var_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype,
|
||||||
|
learning_rate_dtype, rho_dtype, momentum_dtype, epsilon_dtype):
|
||||||
|
validator.check_subclass("var_dtype", var_dtype, mstype.tensor)
|
||||||
|
validator.check_subclass("mean_gradient_dtype", mean_gradient_dtype, mstype.tensor)
|
||||||
|
validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor)
|
||||||
|
validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor)
|
||||||
|
validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor)
|
||||||
|
args = {"var_dtype": var_dtype, "mean_gradient_dtype": mean_gradient_dtype,
|
||||||
|
"mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype, "grad_dtype": grad_dtype}
|
||||||
|
validator.check_type_same(args, mstype.number_type)
|
||||||
|
|
||||||
|
args = {"learning_rate_dtype": learning_rate_dtype, "rho_dtype": rho_dtype, 'momentum_dtype': momentum_dtype,
|
||||||
|
"epsilon_dtype": epsilon_dtype}
|
||||||
|
validator.check_type_same(args, [mstype.float16, mstype.float32])
|
||||||
|
return var_dtype
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(Primitive):
|
class LayerNorm(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
|
|
@ -223,6 +223,10 @@ class InputOpNet(nn.Cell):
|
||||||
x = self.op(x1, x2, x3, x4, x5, self.c1)
|
x = self.op(x1, x2, x3, x4, x5, self.c1)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def construct5_c4(self, x1, x2, x3, x4, x5):
|
||||||
|
x = self.op(x1, x2, x3, x4, x5, self.c1, self.c2, self.c3, self.c4)
|
||||||
|
return x
|
||||||
|
|
||||||
def gen_net(op, input_num, training=True, desc_const=(), const_first=False, add_fake_input=False):
|
def gen_net(op, input_num, training=True, desc_const=(), const_first=False, add_fake_input=False):
|
||||||
if isinstance(op, nn.Cell):
|
if isinstance(op, nn.Cell):
|
||||||
return op
|
return op
|
||||||
|
|
|
@ -810,6 +810,18 @@ test_case_nn_ops = [
|
||||||
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
|
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
|
||||||
'desc_bprop': [3, 3],
|
'desc_bprop': [3, 3],
|
||||||
'skip': ['backward']}),
|
'skip': ['backward']}),
|
||||||
|
('ApplyRMSProp', {
|
||||||
|
'block': P.ApplyRMSProp(),
|
||||||
|
'desc_const': [0.9, 0.0, 1e-10, 0.001],
|
||||||
|
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
|
||||||
|
'desc_bprop': [3, 3],
|
||||||
|
'skip': ['backward']}),
|
||||||
|
('ApplyCenteredRMSProp', {
|
||||||
|
'block': P.ApplyCenteredRMSProp(),
|
||||||
|
'desc_const': [0.9, 0.0, 1e-10, 0.001],
|
||||||
|
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]],
|
||||||
|
'desc_bprop': [3, 3],
|
||||||
|
'skip': ['backward']}),
|
||||||
]
|
]
|
||||||
|
|
||||||
test_case_array_ops = [
|
test_case_array_ops = [
|
||||||
|
|
Loading…
Reference in New Issue