add RMSProp optimizer

This commit is contained in:
zhaoting 2020-03-31 09:14:08 +08:00
parent 5c22c088bb
commit ed3c2d7229
10 changed files with 390 additions and 5 deletions

View File

@ -183,6 +183,8 @@ const char kNameDiagPart[] = "DiagPart";
const char kNameSpaceToBatch[] = "SpaceToBatch"; const char kNameSpaceToBatch[] = "SpaceToBatch";
const char kNameBatchToSpace[] = "BatchToSpace"; const char kNameBatchToSpace[] = "BatchToSpace";
const char kNameAtan2[] = "Atan2"; const char kNameAtan2[] = "Atan2";
const char kNameApplyRMSProp[] = "ApplyRMSProp";
const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp";
// -----------------OpAdapter initialization-------------- // -----------------OpAdapter initialization--------------
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() { std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
@ -367,7 +369,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameDiagPart), ADPT_DESC(DiagPart)}, {string(kNameDiagPart), ADPT_DESC(DiagPart)},
{string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)}, {string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},
{string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)}, {string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)},
{string(kNameAtan2), ADPT_DESC(Atan2)}}; {string(kNameAtan2), ADPT_DESC(Atan2)},
{string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)},
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}};
#ifdef ENABLE_GE #ifdef ENABLE_GE
adpt_map[string(kNamePrint)] = ADPT_DESC(Print); adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
#endif #endif

View File

@ -1202,6 +1202,22 @@ INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(Atan2) = EMPTY_ATTR_MAP; ATTR_MAP(Atan2) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}};
// ApplyRMSPropD
INPUT_MAP(ApplyRMSPropD) = {
{1, INPUT_DESC(var)}, {2, INPUT_DESC(ms)}, {3, INPUT_DESC(mom)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}};
INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits<float>())},
{7, ATTR_DESC(momentum, AnyTraits<float>())},
{8, ATTR_DESC(epsilon, AnyTraits<float>())}};
ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}};
// ApplyCenteredRMSProp
INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)},
{4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)},
{7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}};
ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}};
#ifdef ENABLE_GE #ifdef ENABLE_GE
// Print // Print
INPUT_MAP(Print) = EMPTY_INPUT_MAP; INPUT_MAP(Print) = EMPTY_INPUT_MAP;

View File

@ -445,6 +445,12 @@ DECLARE_OP_ADAPTER(BatchToSpaceD)
DECLARE_OP_USE_OUTPUT(BatchToSpaceD) DECLARE_OP_USE_OUTPUT(BatchToSpaceD)
DECLARE_OP_ADAPTER(Atan2) DECLARE_OP_ADAPTER(Atan2)
DECLARE_OP_USE_OUTPUT(Atan2) DECLARE_OP_USE_OUTPUT(Atan2)
DECLARE_OP_ADAPTER(ApplyRMSPropD)
DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD)
DECLARE_OP_USE_OUTPUT(ApplyRMSPropD)
DECLARE_OP_ADAPTER(ApplyCenteredRMSProp)
DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp)
#ifdef ENABLE_GE #ifdef ENABLE_GE
DECLARE_OP_ADAPTER(Print) DECLARE_OP_ADAPTER(Print)
DECLARE_OP_USE_DYN_INPUT(Print) DECLARE_OP_USE_DYN_INPUT(Print)

View File

@ -25,6 +25,7 @@ from .lamb import Lamb
from .sgd import SGD from .sgd import SGD
from .lars import LARS from .lars import LARS
from .ftrl import FTRL from .ftrl import FTRL
from .rmsprop import RMSProp
__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', __all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay',
'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL'] 'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'RMSProp']

View File

@ -0,0 +1,187 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""rmsprop"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore._checkparam import ParamValidator as validator
import mindspore.common.dtype as mstype
from .optimizer import Optimizer, grad_scale
rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad):
"""Apply rmsprop optimizer to the weight parameter."""
success = True
success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
return success
@rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
def _rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad):
"""Apply rmsprop optimizer to the weight parameter using dynamic learning rate."""
success = True
success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
return success
@centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor")
def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad):
"""Apply centered rmsprop optimizer to the weight parameter."""
success = True
success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))
return success
@centered_rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor")
def _centered_rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad):
"""Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate."""
success = True
success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))
return success
class RMSProp(Optimizer):
"""
Implements Root Mean Squared Propagation (RMSProp) algorithm.
Note:
Update `params` according to the RMSProp algorithm.
The equation is as follows:
.. math::
s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2
.. math::
m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} + \\epsilon}} \\nabla Q_{i}(w)
.. math::
w = w - m_{t}
The first equation calculates moving average of the squared gradient for
each weight. Then dividing the gradient by :math:`\\sqrt{ms_{t} + \\epsilon}`.
if centered is True:
.. math::
g_{t} = \\rho g_{t-1} + (1 - \\rho)\\nabla Q_{i}(w)
.. math::
s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2
.. math::
m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} - g_{t}^2 + \\epsilon}} \\nabla Q_{i}(w)
.. math::
w = w - m_{t}
where, :math:`w` represents `params`, which will be updated.
:math:`g_{t}` is mean gradients, :math:`g_{t-1}` is the last moment of :math:`g_{t}`.
:math:`s_{t}` is the mean square gradients, :math:`s_{t-1}` is the last moment of :math:`s_{t}`,
:math:`m_{t}` is moment, the delta of `w`, :math:`m_{t-1}` is the last moment of :math:`m_{t}`.
:math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, represents `momentum`.
:math:`\\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`.
:math:`\\eta` is learning rate, represents `learning_rate`. :math:`\\nabla Q_{i}(w)` is gradientse,
represents `gradients`.
Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `parameters`
should be class mindspore.Parameter.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
take the i-th value as the learning rate.
When the learning_rate is float or learning_rate is a Tensor
but the dims of the Tensor is 0, use fixed learning rate.
Other cases are not supported.
decay (float): Decay rate.
momentum (float): Hyperparameter of type float, means momentum for the moving average.
epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False.
centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False
loss_scale (float): A floating point value for the loss scale. Default: 1.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
Tensor[bool], the value is True.
Examples:
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = RMSProp(params=net.trainable_params(), learning_rate=lr)
>>> model = Model(net, loss, opt)
"""
def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
use_locking=False, centered=False, loss_scale=1.0):
super(RMSProp, self).__init__(learning_rate, params)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
if decay < 0.0:
raise ValueError("decay should be at least 0.0, but got dampening {}".format(decay))
self.decay = decay
self.epsilon = epsilon
validator.check_type("use_locking", use_locking, [bool])
validator.check_type("centered", centered, [bool])
self.centered = centered
if centered:
self.opt = P.ApplyCenteredRMSProp(use_locking)
self.mg = self.parameters.clone(prefix="mean_grad", init='zeros')
else:
self.opt = P.ApplyRMSProp(use_locking)
self.dynamic_lr = False
if not isinstance(learning_rate, float):
self.dynamic_lr = True
self.gather = P.GatherV2()
self.assignadd = P.AssignAdd()
self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
self.axis = 0
self.momentum = momentum
self.ms = self.parameters.clone(prefix="mean_square", init='zeros')
self.moment = self.parameters.clone(prefix="moment", init='zeros')
self.hyper_map = C.HyperMap()
self.decay = decay
self.reciprocal_scale = 1.0 / loss_scale
def construct(self, gradients):
params = self.parameters
if self.reciprocal_scale != 1.0:
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
if self.dynamic_lr:
lr = self.gather(self.learning_rate, self.global_step, self.axis)
F.control_depend(lr, self.assignadd(self.global_step, self.one))
else:
lr = self.learning_rate
if self.centered:
success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
self.momentum), params, self.mg, self.ms, self.moment, gradients)
else:
success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
self.momentum), params, self.ms, self.moment, gradients)
return success

View File

@ -394,8 +394,8 @@ def _split_shape_index(input_shape, axis):
axis = tuple([axis]) axis = tuple([axis])
reduction_indices = tuple([(i + rank) % rank for i in axis]) reduction_indices = tuple([(i + rank) % rank for i in axis])
other_indices = tuple(set(range(rank)) - set(reduction_indices)) other_indices = tuple(set(range(rank)) - set(reduction_indices))
reduced_num = reduce(lambda x, y: x * y, [input_shape[i] for i in reduction_indices]) reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices])
other_num = reduce(lambda x, y: x * y, [input_shape[i] for i in other_indices]) other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices])
perm = reduction_indices + other_indices perm = reduction_indices + other_indices
return tuple([reduced_num, other_num]), perm return tuple([reduced_num, other_num]), perm

View File

@ -65,7 +65,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
SmoothL1Loss, Softmax, SmoothL1Loss, Softmax,
SoftmaxCrossEntropyWithLogits, ROIAlign, SoftmaxCrossEntropyWithLogits, ROIAlign,
SparseSoftmaxCrossEntropyWithLogits, Tanh, SparseSoftmaxCrossEntropyWithLogits, Tanh,
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl) TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,
ApplyRMSProp, ApplyCenteredRMSProp)
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey
@ -228,6 +229,8 @@ __all__ = [
"SpaceToBatch", "SpaceToBatch",
"BatchToSpace", "BatchToSpace",
"Atan2", "Atan2",
"ApplyRMSProp",
"ApplyCenteredRMSProp"
] ]
__all__.sort() __all__.sort()

View File

@ -1359,6 +1359,158 @@ class SGD(PrimitiveWithInfer):
validator.check_typename("stat_dtype", stat_dtype, [mstype.float16, mstype.float32]) validator.check_typename("stat_dtype", stat_dtype, [mstype.float16, mstype.float32])
return parameters_dtype return parameters_dtype
class ApplyRMSProp(PrimitiveWithInfer):
"""
Optimizer that implements the Root Mean Square prop(RMSProp) algorithm.
Note:
Update `var` according to the RMSProp algorithm.
.. math::
s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2
.. math::
m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} + \\epsilon}} \\nabla Q_{i}(w)
.. math::
w = w - m_{t}
where, :math:`w` represents `var`, which will be updated.
:math:`s_{t}` represents `mean_square`, :math:`s_{t-1}` is the last momentent of :math:`s_{t}`,
:math:`m_{t}` represents `moment`, :math:`m_{t-1}` is the last momentent of :math:`m_{t}`.
:math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, represents `momentum`.
:math:`\\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`.
:math:`\\eta` represents `learning_rate`. :math:`\\nabla Q_{i}(w)` represents `grad`.
Args:
use_locking (bool): Enable a lock to protect the update of variable tensors. Default: False.
Inputs:
- **var** (Tensor) - Weights to be update.
- **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`.
- **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
- **grad** (Tensor) - Gradients, must have the same type as `var`.
- **learning_rate** (Union[Number, Tensor]) - Learning rate.
- **decay** (float) - Decay rate.
- **momentum** (float) - Momentum.
- **epsilon** (float) - Ridge term.
Outputs:
Tensor, parameters to be update.
Examples:
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = RMSProp(params=net.trainable_params(), learning_rate=learning_rate)
>>> model = Model(net, loss, opt)
"""
@prim_attr_register
def __init__(self, use_locking=False):
self.use_locking = validator.check_type("use_locking", use_locking, [bool])
def infer_shape(self, var_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape,
momentum_shape, epsilon_shape):
validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape)
validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape)
validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape)
return var_shape
def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, decay_dtype,
momentum_dtype, epsilon_dtype):
validator.check_subclass("var_dtype", var_dtype, mstype.tensor)
validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor)
validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor)
validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor)
args = {"var_dtype": var_dtype, "mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype,
"grad_dtype": grad_dtype}
validator.check_type_same(args, mstype.number_type)
args = {"learning_rate_dtype": learning_rate_dtype, "decay_dtype": decay_dtype,
'momentum_dtype': momentum_dtype, "epsilon_dtype": epsilon_dtype}
validator.check_type_same(args, [mstype.float16, mstype.float32])
return var_dtype
class ApplyCenteredRMSProp(PrimitiveWithInfer):
"""
Optimizer that implements the centered RMSProp algorithm.
Note:
Update `var` according to the centered RMSProp algorithm.
.. math::
g_{t} = \\rho g_{t-1} + (1 - \\rho)\\nabla Q_{i}(w)
.. math::
s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2
.. math::
m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} - g_{t}^2 + \\epsilon}} \\nabla Q_{i}(w)
.. math::
w = w - m_{t}
where, :math:`w` represents `var`, which will be updated.
:math:`g_{t}` represents `mean_gradient`, :math:`g_{t-1}` is the last momentent of :math:`g_{t}`.
:math:`s_{t}` represents `mean_square`, :math:`s_{t-1}` is the last momentent of :math:`s_{t}`,
:math:`m_{t}` represents `moment`, :math:`m_{t-1}` is the last momentent of :math:`m_{t}`.
:math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, represents `momentum`.
:math:`\\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`.
:math:`\\eta` represents `learning_rate`. :math:`\\nabla Q_{i}(w)` represents `grad`.
Args:
use_locking (bool): Enable a lock to protect the update of variable tensors. Default: False.
Inputs:
- **var** (Tensor) - Weights to be update.
- **mean_gradient** (Tensor) - Mean gradients, must have the same type as `var`.
- **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`.
- **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
- **grad** (Tensor) - Gradients, must have the same type as `var`.
- **learning_rate** (Union[Number, Tensor]) - Learning rate.
- **decay** (float) - Decay rate.
- **momentum** (float) - Momentum.
- **epsilon** (float) - Ridge term.
Outputs:
Tensor, parameters to be update.
Examples:
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = RMSProp(params=net.trainable_params(), learning_rate=learning_rate, centered=True)
>>> model = Model(net, loss, opt)
"""
@prim_attr_register
def __init__(self, use_locking=False):
self.use_locking = validator.check_type("use_locking", use_locking, [bool])
def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape,
learning_rate_shape, decay_shape, momentum_shape, epsilon_shape):
validator.check_param_equal("var_shape", var_shape, "mean_gradient_shape", mean_gradient_shape)
validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape)
validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape)
validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape)
return var_shape
def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype,
learning_rate_dtype, rho_dtype, momentum_dtype, epsilon_dtype):
validator.check_subclass("var_dtype", var_dtype, mstype.tensor)
validator.check_subclass("mean_gradient_dtype", mean_gradient_dtype, mstype.tensor)
validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor)
validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor)
validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor)
args = {"var_dtype": var_dtype, "mean_gradient_dtype": mean_gradient_dtype,
"mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype, "grad_dtype": grad_dtype}
validator.check_type_same(args, mstype.number_type)
args = {"learning_rate_dtype": learning_rate_dtype, "rho_dtype": rho_dtype, 'momentum_dtype': momentum_dtype,
"epsilon_dtype": epsilon_dtype}
validator.check_type_same(args, [mstype.float16, mstype.float32])
return var_dtype
class LayerNorm(Primitive): class LayerNorm(Primitive):
r""" r"""

View File

@ -223,6 +223,10 @@ class InputOpNet(nn.Cell):
x = self.op(x1, x2, x3, x4, x5, self.c1) x = self.op(x1, x2, x3, x4, x5, self.c1)
return x return x
def construct5_c4(self, x1, x2, x3, x4, x5):
x = self.op(x1, x2, x3, x4, x5, self.c1, self.c2, self.c3, self.c4)
return x
def gen_net(op, input_num, training=True, desc_const=(), const_first=False, add_fake_input=False): def gen_net(op, input_num, training=True, desc_const=(), const_first=False, add_fake_input=False):
if isinstance(op, nn.Cell): if isinstance(op, nn.Cell):
return op return op

View File

@ -805,6 +805,18 @@ test_case_nn_ops = [
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]], 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
'desc_bprop': [3, 3], 'desc_bprop': [3, 3],
'skip': ['backward']}), 'skip': ['backward']}),
('ApplyRMSProp', {
'block': P.ApplyRMSProp(),
'desc_const': [0.9, 0.0, 1e-10, 0.001],
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
'desc_bprop': [3, 3],
'skip': ['backward']}),
('ApplyCenteredRMSProp', {
'block': P.ApplyCenteredRMSProp(),
'desc_const': [0.9, 0.0, 1e-10, 0.001],
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]],
'desc_bprop': [3, 3],
'skip': ['backward']}),
] ]
test_case_array_ops = [ test_case_array_ops = [