From ed3c2d7229c1d3ca34e3ce95cb0a7acf2104564c Mon Sep 17 00:00:00 2001 From: zhaoting Date: Tue, 31 Mar 2020 09:14:08 +0800 Subject: [PATCH] add RMSProp optimizer --- mindspore/ccsrc/transform/convert.cc | 6 +- mindspore/ccsrc/transform/op_declare.cc | 16 ++ mindspore/ccsrc/transform/op_declare.h | 6 + mindspore/nn/optim/__init__.py | 3 +- mindspore/nn/optim/rmsprop.py | 187 ++++++++++++++++++ mindspore/ops/_grad/grad_math_ops.py | 4 +- mindspore/ops/operations/__init__.py | 5 +- mindspore/ops/operations/nn_ops.py | 152 ++++++++++++++ .../utils/block_util.py | 4 + tests/ut/python/ops/test_ops.py | 12 ++ 10 files changed, 390 insertions(+), 5 deletions(-) create mode 100644 mindspore/nn/optim/rmsprop.py diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc index 48056c38dad..fdacff7ba8a 100755 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -183,6 +183,8 @@ const char kNameDiagPart[] = "DiagPart"; const char kNameSpaceToBatch[] = "SpaceToBatch"; const char kNameBatchToSpace[] = "BatchToSpace"; const char kNameAtan2[] = "Atan2"; +const char kNameApplyRMSProp[] = "ApplyRMSProp"; +const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp"; // -----------------OpAdapter initialization-------------- std::unordered_map &DfGraphConvertor::get_adpt_map() { @@ -367,7 +369,9 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameDiagPart), ADPT_DESC(DiagPart)}, {string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)}, {string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)}, - {string(kNameAtan2), ADPT_DESC(Atan2)}}; + {string(kNameAtan2), ADPT_DESC(Atan2)}, + {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)}, + {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}}; #ifdef ENABLE_GE adpt_map[string(kNamePrint)] = ADPT_DESC(Print); #endif diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index f7fdcfbe56d..9258eb08db2 100755 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -1202,6 +1202,22 @@ INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; ATTR_MAP(Atan2) = EMPTY_ATTR_MAP; OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}}; +// ApplyRMSPropD +INPUT_MAP(ApplyRMSPropD) = { + {1, INPUT_DESC(var)}, {2, INPUT_DESC(ms)}, {3, INPUT_DESC(mom)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}}; +INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits())}, + {7, ATTR_DESC(momentum, AnyTraits())}, + {8, ATTR_DESC(epsilon, AnyTraits())}}; +ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}}; + +// ApplyCenteredRMSProp +INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)}, + {4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)}, + {7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}}; +ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}}; + #ifdef ENABLE_GE // Print INPUT_MAP(Print) = EMPTY_INPUT_MAP; diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h index 1924d2719b0..031ce80865e 100755 --- a/mindspore/ccsrc/transform/op_declare.h +++ b/mindspore/ccsrc/transform/op_declare.h @@ -445,6 +445,12 @@ DECLARE_OP_ADAPTER(BatchToSpaceD) DECLARE_OP_USE_OUTPUT(BatchToSpaceD) DECLARE_OP_ADAPTER(Atan2) DECLARE_OP_USE_OUTPUT(Atan2) +DECLARE_OP_ADAPTER(ApplyRMSPropD) +DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD) +DECLARE_OP_USE_OUTPUT(ApplyRMSPropD) +DECLARE_OP_ADAPTER(ApplyCenteredRMSProp) +DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp) + #ifdef ENABLE_GE DECLARE_OP_ADAPTER(Print) DECLARE_OP_USE_DYN_INPUT(Print) diff --git a/mindspore/nn/optim/__init__.py b/mindspore/nn/optim/__init__.py index 6f7f6fbd46e..8f211798934 100644 --- a/mindspore/nn/optim/__init__.py +++ b/mindspore/nn/optim/__init__.py @@ -25,6 +25,7 @@ from .lamb import Lamb from .sgd import SGD from .lars import LARS from .ftrl import FTRL +from .rmsprop import RMSProp __all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', - 'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL'] + 'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'RMSProp'] diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py new file mode 100644 index 00000000000..3000fdeeee3 --- /dev/null +++ b/mindspore/nn/optim/rmsprop.py @@ -0,0 +1,187 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""rmsprop""" +from mindspore.ops import functional as F, composite as C, operations as P +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from mindspore._checkparam import ParamValidator as validator +import mindspore.common.dtype as mstype +from .optimizer import Optimizer, grad_scale + +rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") +centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") + + +@rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") +def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad): + """Apply rmsprop optimizer to the weight parameter.""" + success = True + success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon)) + return success + + +@rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") +def _rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad): + """Apply rmsprop optimizer to the weight parameter using dynamic learning rate.""" + success = True + success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon)) + return success + + +@centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", + "Tensor", "Tensor") +def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad): + """Apply centered rmsprop optimizer to the weight parameter.""" + success = True + success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon)) + return success + + +@centered_rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", + "Tensor", "Tensor") +def _centered_rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad): + """Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate.""" + success = True + success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon)) + return success + + +class RMSProp(Optimizer): + """ + Implements Root Mean Squared Propagation (RMSProp) algorithm. + + Note: + Update `params` according to the RMSProp algorithm. + + The equation is as follows: + + .. math:: + s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2 + + .. math:: + m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} + \\epsilon}} \\nabla Q_{i}(w) + + .. math:: + w = w - m_{t} + + The first equation calculates moving average of the squared gradient for + each weight. Then dividing the gradient by :math:`\\sqrt{ms_{t} + \\epsilon}`. + + if centered is True: + + .. math:: + g_{t} = \\rho g_{t-1} + (1 - \\rho)\\nabla Q_{i}(w) + + .. math:: + s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2 + + .. math:: + m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} - g_{t}^2 + \\epsilon}} \\nabla Q_{i}(w) + + .. math:: + w = w - m_{t} + + where, :math:`w` represents `params`, which will be updated. + :math:`g_{t}` is mean gradients, :math:`g_{t-1}` is the last moment of :math:`g_{t}`. + :math:`s_{t}` is the mean square gradients, :math:`s_{t-1}` is the last moment of :math:`s_{t}`, + :math:`m_{t}` is moment, the delta of `w`, :math:`m_{t-1}` is the last moment of :math:`m_{t}`. + :math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, represents `momentum`. + :math:`\\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`. + :math:`\\eta` is learning rate, represents `learning_rate`. :math:`\\nabla Q_{i}(w)` is gradientse, + represents `gradients`. + + Args: + params (list[Parameter]): A list of parameter, which will be updated. The element in `parameters` + should be class mindspore.Parameter. + learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is + Iterable or a Tensor and the dims of the Tensor is 1, + use dynamic learning rate, then the i-th step will + take the i-th value as the learning rate. + When the learning_rate is float or learning_rate is a Tensor + but the dims of the Tensor is 0, use fixed learning rate. + Other cases are not supported. + decay (float): Decay rate. + momentum (float): Hyperparameter of type float, means momentum for the moving average. + epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than 0. + use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False. + centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False + loss_scale (float): A floating point value for the loss scale. Default: 1.0. + + Inputs: + - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. + + Outputs: + Tensor[bool], the value is True. + + Examples: + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> opt = RMSProp(params=net.trainable_params(), learning_rate=lr) + >>> model = Model(net, loss, opt) + """ + def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10, + use_locking=False, centered=False, loss_scale=1.0): + super(RMSProp, self).__init__(learning_rate, params) + + if isinstance(momentum, float) and momentum < 0.0: + raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) + + if decay < 0.0: + raise ValueError("decay should be at least 0.0, but got dampening {}".format(decay)) + self.decay = decay + self.epsilon = epsilon + + validator.check_type("use_locking", use_locking, [bool]) + validator.check_type("centered", centered, [bool]) + self.centered = centered + if centered: + self.opt = P.ApplyCenteredRMSProp(use_locking) + self.mg = self.parameters.clone(prefix="mean_grad", init='zeros') + else: + self.opt = P.ApplyRMSProp(use_locking) + + self.dynamic_lr = False + if not isinstance(learning_rate, float): + self.dynamic_lr = True + self.gather = P.GatherV2() + self.assignadd = P.AssignAdd() + self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step") + self.axis = 0 + + self.momentum = momentum + + self.ms = self.parameters.clone(prefix="mean_square", init='zeros') + self.moment = self.parameters.clone(prefix="moment", init='zeros') + self.hyper_map = C.HyperMap() + + self.decay = decay + self.reciprocal_scale = 1.0 / loss_scale + + def construct(self, gradients): + params = self.parameters + if self.reciprocal_scale != 1.0: + gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients) + if self.dynamic_lr: + lr = self.gather(self.learning_rate, self.global_step, self.axis) + F.control_depend(lr, self.assignadd(self.global_step, self.one)) + else: + lr = self.learning_rate + if self.centered: + success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon, + self.momentum), params, self.mg, self.ms, self.moment, gradients) + else: + success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay, self.epsilon, + self.momentum), params, self.ms, self.moment, gradients) + return success diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index 9e90c5660cb..1675855c88d 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -394,8 +394,8 @@ def _split_shape_index(input_shape, axis): axis = tuple([axis]) reduction_indices = tuple([(i + rank) % rank for i in axis]) other_indices = tuple(set(range(rank)) - set(reduction_indices)) - reduced_num = reduce(lambda x, y: x * y, [input_shape[i] for i in reduction_indices]) - other_num = reduce(lambda x, y: x * y, [input_shape[i] for i in other_indices]) + reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices]) + other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices]) perm = reduction_indices + other_indices return tuple([reduced_num, other_num]), perm diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 5c98568b8a3..727ddaf88f0 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -65,7 +65,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, SmoothL1Loss, Softmax, SoftmaxCrossEntropyWithLogits, ROIAlign, SparseSoftmaxCrossEntropyWithLogits, Tanh, - TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl) + TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, + ApplyRMSProp, ApplyCenteredRMSProp) from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey @@ -228,6 +229,8 @@ __all__ = [ "SpaceToBatch", "BatchToSpace", "Atan2", + "ApplyRMSProp", + "ApplyCenteredRMSProp" ] __all__.sort() diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index afa4c7dfe38..b9ab7e8dc90 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1359,6 +1359,158 @@ class SGD(PrimitiveWithInfer): validator.check_typename("stat_dtype", stat_dtype, [mstype.float16, mstype.float32]) return parameters_dtype +class ApplyRMSProp(PrimitiveWithInfer): + """ + Optimizer that implements the Root Mean Square prop(RMSProp) algorithm. + + Note: + Update `var` according to the RMSProp algorithm. + + .. math:: + s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2 + + .. math:: + m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} + \\epsilon}} \\nabla Q_{i}(w) + + .. math:: + w = w - m_{t} + + where, :math:`w` represents `var`, which will be updated. + :math:`s_{t}` represents `mean_square`, :math:`s_{t-1}` is the last momentent of :math:`s_{t}`, + :math:`m_{t}` represents `moment`, :math:`m_{t-1}` is the last momentent of :math:`m_{t}`. + :math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, represents `momentum`. + :math:`\\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`. + :math:`\\eta` represents `learning_rate`. :math:`\\nabla Q_{i}(w)` represents `grad`. + + Args: + use_locking (bool): Enable a lock to protect the update of variable tensors. Default: False. + + Inputs: + - **var** (Tensor) - Weights to be update. + - **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`. + - **moment** (Tensor) - Delta of `var`, must have the same type as `var`. + - **grad** (Tensor) - Gradients, must have the same type as `var`. + - **learning_rate** (Union[Number, Tensor]) - Learning rate. + - **decay** (float) - Decay rate. + - **momentum** (float) - Momentum. + - **epsilon** (float) - Ridge term. + + Outputs: + Tensor, parameters to be update. + + Examples: + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> opt = RMSProp(params=net.trainable_params(), learning_rate=learning_rate) + >>> model = Model(net, loss, opt) + """ + + @prim_attr_register + def __init__(self, use_locking=False): + self.use_locking = validator.check_type("use_locking", use_locking, [bool]) + + def infer_shape(self, var_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape, + momentum_shape, epsilon_shape): + validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape) + validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape) + validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) + return var_shape + + def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, decay_dtype, + momentum_dtype, epsilon_dtype): + validator.check_subclass("var_dtype", var_dtype, mstype.tensor) + validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor) + validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor) + validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor) + args = {"var_dtype": var_dtype, "mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype, + "grad_dtype": grad_dtype} + validator.check_type_same(args, mstype.number_type) + + args = {"learning_rate_dtype": learning_rate_dtype, "decay_dtype": decay_dtype, + 'momentum_dtype': momentum_dtype, "epsilon_dtype": epsilon_dtype} + validator.check_type_same(args, [mstype.float16, mstype.float32]) + return var_dtype + + +class ApplyCenteredRMSProp(PrimitiveWithInfer): + """ + Optimizer that implements the centered RMSProp algorithm. + + Note: + Update `var` according to the centered RMSProp algorithm. + + .. math:: + g_{t} = \\rho g_{t-1} + (1 - \\rho)\\nabla Q_{i}(w) + + .. math:: + s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2 + + .. math:: + m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} - g_{t}^2 + \\epsilon}} \\nabla Q_{i}(w) + + .. math:: + w = w - m_{t} + + where, :math:`w` represents `var`, which will be updated. + :math:`g_{t}` represents `mean_gradient`, :math:`g_{t-1}` is the last momentent of :math:`g_{t}`. + :math:`s_{t}` represents `mean_square`, :math:`s_{t-1}` is the last momentent of :math:`s_{t}`, + :math:`m_{t}` represents `moment`, :math:`m_{t-1}` is the last momentent of :math:`m_{t}`. + :math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, represents `momentum`. + :math:`\\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`. + :math:`\\eta` represents `learning_rate`. :math:`\\nabla Q_{i}(w)` represents `grad`. + + Args: + use_locking (bool): Enable a lock to protect the update of variable tensors. Default: False. + + Inputs: + - **var** (Tensor) - Weights to be update. + - **mean_gradient** (Tensor) - Mean gradients, must have the same type as `var`. + - **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`. + - **moment** (Tensor) - Delta of `var`, must have the same type as `var`. + - **grad** (Tensor) - Gradients, must have the same type as `var`. + - **learning_rate** (Union[Number, Tensor]) - Learning rate. + - **decay** (float) - Decay rate. + - **momentum** (float) - Momentum. + - **epsilon** (float) - Ridge term. + + Outputs: + Tensor, parameters to be update. + + Examples: + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> opt = RMSProp(params=net.trainable_params(), learning_rate=learning_rate, centered=True) + >>> model = Model(net, loss, opt) + """ + + @prim_attr_register + def __init__(self, use_locking=False): + self.use_locking = validator.check_type("use_locking", use_locking, [bool]) + + def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape, + learning_rate_shape, decay_shape, momentum_shape, epsilon_shape): + validator.check_param_equal("var_shape", var_shape, "mean_gradient_shape", mean_gradient_shape) + validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape) + validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape) + validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) + return var_shape + + def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype, + learning_rate_dtype, rho_dtype, momentum_dtype, epsilon_dtype): + validator.check_subclass("var_dtype", var_dtype, mstype.tensor) + validator.check_subclass("mean_gradient_dtype", mean_gradient_dtype, mstype.tensor) + validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor) + validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor) + validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor) + args = {"var_dtype": var_dtype, "mean_gradient_dtype": mean_gradient_dtype, + "mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype, "grad_dtype": grad_dtype} + validator.check_type_same(args, mstype.number_type) + + args = {"learning_rate_dtype": learning_rate_dtype, "rho_dtype": rho_dtype, 'momentum_dtype': momentum_dtype, + "epsilon_dtype": epsilon_dtype} + validator.check_type_same(args, [mstype.float16, mstype.float32]) + return var_dtype + class LayerNorm(Primitive): r""" diff --git a/tests/mindspore_test_framework/utils/block_util.py b/tests/mindspore_test_framework/utils/block_util.py index 9d75ae0888a..b4a926c15d5 100644 --- a/tests/mindspore_test_framework/utils/block_util.py +++ b/tests/mindspore_test_framework/utils/block_util.py @@ -223,6 +223,10 @@ class InputOpNet(nn.Cell): x = self.op(x1, x2, x3, x4, x5, self.c1) return x + def construct5_c4(self, x1, x2, x3, x4, x5): + x = self.op(x1, x2, x3, x4, x5, self.c1, self.c2, self.c3, self.c4) + return x + def gen_net(op, input_num, training=True, desc_const=(), const_first=False, add_fake_input=False): if isinstance(op, nn.Cell): return op diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 117036c37e0..453ef9a652f 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -805,6 +805,18 @@ test_case_nn_ops = [ 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]], 'desc_bprop': [3, 3], 'skip': ['backward']}), + ('ApplyRMSProp', { + 'block': P.ApplyRMSProp(), + 'desc_const': [0.9, 0.0, 1e-10, 0.001], + 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]], + 'desc_bprop': [3, 3], + 'skip': ['backward']}), + ('ApplyCenteredRMSProp', { + 'block': P.ApplyCenteredRMSProp(), + 'desc_const': [0.9, 0.0, 1e-10, 0.001], + 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]], + 'desc_bprop': [3, 3], + 'skip': ['backward']}), ] test_case_array_ops = [