forked from mindspore-Ecosystem/mindspore
!839 add parameter verification for RMSprop
Merge pull request !839 from wangnan39/add_parameter_verification_for_rmsprop
commit 883fde0494
@@ -145,9 +145,12 @@ class Adam(Optimizer):
             When the learning_rate is float or learning_rate is a Tensor
             but the dims of the Tensor is 0, use fixed learning rate.
             Other cases are not supported. Default: 1e-3.
-        beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
-        beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
-        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+        beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default:
+            0.9.
+        beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default:
+            0.999.
+        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
+            1e-8.
         use_locking (bool): Whether to enable a lock to protect updating variable tensors.
             If True, updating of the var, m, and v tensors will be protected by a lock.
             If False, the result is unpredictable. Default: False.
@@ -155,8 +158,8 @@ class Adam(Optimizer):
             If True, updates the gradients using NAG.
             If False, updates the gradients without using NAG. Default: False.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
-        loss_scale (float): A floating point value for the loss scale. Default: 1.0.
-            Should be equal to or greater than 1.
+        loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
+            1.0.
         decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
             lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
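For reference, a minimal construction of Adam that spells out the defaults documented above. This is an illustrative sketch only: the `params` list stands in for a real network's trainable parameters, and it assumes this MindSpore version exposes the optimizer as nn.Adam with the keyword names listed in the docstring.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter

# stand-in parameter list; normally this would be net.trainable_params()
params = [Parameter(Tensor(np.ones([3, 3]).astype(np.float32)), name="w")]
opt = nn.Adam(params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
              use_locking=False, use_nesterov=False, weight_decay=0.0, loss_scale=1.0)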
@@ -46,8 +46,8 @@ class Optimizer(Cell):
         learning_rate (float): A floating point value for the learning rate. Should be greater than 0.
         parameters (list): A list of parameter, which will be updated. The element in `parameters`
             should be class mindspore.Parameter.
-        weight_decay (float): A floating point value for the weight decay. If the type of `weight_decay`
-            input is int, it will be converted to float. Default: 0.0.
+        weight_decay (float): A floating point value for the weight decay. It should be equal to or greater than 0.
+            If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
         loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the
             type of `loss_scale` input is int, it will be converted to float. Default: 1.0.
         decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda
@@ -87,21 +87,15 @@ class Optimizer(Cell):
         if isinstance(weight_decay, int):
             weight_decay = float(weight_decay)

-        validator.check_float_legal_value('weight_decay', weight_decay, None)
+        validator.check_value_type("weight_decay", weight_decay, [float], None)
+        validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None)

         if isinstance(loss_scale, int):
             loss_scale = float(loss_scale)
+        validator.check_value_type("loss_scale", loss_scale, [float], None)
+        validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, None)

-        validator.check_float_legal_value('loss_scale', loss_scale, None)
-
-        if loss_scale <= 0.0:
-            raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale))
         self.loss_scale = loss_scale

-        if weight_decay < 0.0:
-            raise ValueError("Weight decay should be equal or greater than 0, but got {}".format(weight_decay))
-
         self.learning_rate = Parameter(learning_rate, name="learning_rate")
         self.parameters = ParameterTuple(parameters)
         self.reciprocal_scale = 1.0 / loss_scale
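The same checks can be exercised directly with the validator helpers used above. A quick sketch of the expected behaviour, assuming (as the new unit test below also does) that out-of-range values raise ValueError and wrong types raise TypeError:

import pytest
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel

# the checks the Optimizer base class now applies to weight_decay and loss_scale
validator.check_value_type("weight_decay", 0.5, [float], None)                                  # passes
validator.check_number_range("weight_decay", 0.5, 0.0, float("inf"), Rel.INC_LEFT, None)        # passes, range [0, inf)
with pytest.raises(ValueError):
    validator.check_number_range("weight_decay", -0.1, 0.0, float("inf"), Rel.INC_LEFT, None)   # negative rejected
with pytest.raises(ValueError):
    validator.check_number_range("loss_scale", 0.0, 0.0, float("inf"), Rel.INC_NEITHER, None)   # 0 excluded, range (0, inf)
with pytest.raises(TypeError):
    validator.check_value_type("loss_scale", 1, [float], None)                                  # int is not accepted as float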
@@ -15,6 +15,7 @@
 """rmsprop"""
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Rel
 from .optimizer import Optimizer

 rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@@ -91,14 +92,16 @@ class RMSProp(Optimizer):
             take the i-th value as the learning rate.
             When the learning_rate is float or learning_rate is a Tensor
             but the dims of the Tensor is 0, use fixed learning rate.
-            Other cases are not supported.
-        decay (float): Decay rate.
-        momentum (float): Hyperparameter of type float, means momentum for the moving average.
-        epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+            Other cases are not supported. Default: 0.1.
+        decay (float): Decay rate. Should be equal to or greater than 0. Default: 0.9.
+        momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or
+            greater than 0. Default: 0.0.
+        epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than
+            0. Default: 1e-10.
         use_locking (bool): Enable a lock to protect the update of variable and accumulation tensors. Default: False.
-        centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False
-        loss_scale (float): A floating point value for the loss scale. Default: 1.0.
-        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
+        centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False.
+        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
+        weight_decay (float): Weight decay (L2 penalty). Should be equal to or greater than 0. Default: 0.0.
         decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
             lambda x: 'beta' not in x.name and 'gamma' not in x.name.
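For context, these hyperparameters drive the usual RMSProp recurrences. Below is a plain-NumPy sketch of a single update step in the standard formulation; it is illustrative only, not MindSpore's fused ApplyRMSProp/ApplyCenteredRMSProp kernels.

import numpy as np

def rmsprop_step(w, grad, mean_square, moment, mean_grad=None,
                 lr=0.1, decay=0.9, momentum=0.0, epsilon=1e-10, centered=False):
    """One illustrative RMSProp update on NumPy arrays."""
    # running average of the squared gradient
    mean_square = decay * mean_square + (1.0 - decay) * grad * grad
    if centered:
        # normalize by an estimate of the gradient variance instead of the raw second moment
        mean_grad = decay * mean_grad + (1.0 - decay) * grad
        denom = mean_square - mean_grad * mean_grad + epsilon
    else:
        denom = mean_square + epsilon
    # momentum accumulator and parameter update
    moment = momentum * moment + lr * grad / np.sqrt(denom)
    return w - moment, mean_square, moment, mean_grad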
@@ -118,17 +121,15 @@ class RMSProp(Optimizer):
                  use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
         super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)

-        if isinstance(momentum, float) and momentum < 0.0:
-            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
-
-        if decay < 0.0:
-            raise ValueError("decay should be at least 0.0, but got dampening {}".format(decay))
-        self.decay = decay
-        self.epsilon = epsilon
+        validator.check_value_type("decay", decay, [float], self.cls_name)
+        validator.check_number_range("decay", decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+        validator.check_value_type("momentum", momentum, [float], self.cls_name)
+        validator.check_number_range("momentum", momentum, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+        validator.check_value_type("epsilon", epsilon, [float], self.cls_name)
+        validator.check_number_range("epsilon", epsilon, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
+        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
+        validator.check_value_type("centered", centered, [bool], self.cls_name)

         self.centered = centered
         if centered:
             self.opt = P.ApplyCenteredRMSProp(use_locking)
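With these checks in place, malformed hyperparameters fail at construction time rather than during training. An illustrative sketch mirroring the new unit test below; the `params` list stands in for real trainable parameters.

import numpy as np
import pytest
from mindspore import Tensor, Parameter
from mindspore.nn.optim import RMSProp

params = [Parameter(Tensor(np.ones([3, 3]).astype(np.float32)), name="w")]

RMSProp(params, learning_rate=0.1)                    # defaults satisfy every check
with pytest.raises(ValueError):
    RMSProp(params, learning_rate=0.1, epsilon=0.0)   # epsilon must be strictly greater than 0
with pytest.raises(TypeError):
    RMSProp(params, learning_rate=0.1, centered=1)    # centered must be a bool, not an int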
@@ -137,11 +138,10 @@ class RMSProp(Optimizer):
             self.opt = P.ApplyRMSProp(use_locking)

         self.momentum = momentum

         self.ms = self.parameters.clone(prefix="mean_square", init='zeros')
         self.moment = self.parameters.clone(prefix="moment", init='zeros')
         self.hyper_map = C.HyperMap()
+        self.epsilon = epsilon
+        self.decay = decay

     def construct(self, gradients):
@@ -49,12 +49,12 @@ class SGD(Optimizer):
             When the learning_rate is float or learning_rate is a Tensor
             but the dims of the Tensor is 0, use fixed learning rate.
             Other cases are not supported. Default: 0.1.
-        momentum (float): A floating point value the momentum. Default: 0.
-        dampening (float): A floating point value of dampening for momentum. Default: 0.
-        weight_decay (float): Weight decay (L2 penalty). Default: 0.
+        momentum (float): A floating point value the momentum. Default: 0.0.
+        dampening (float): A floating point value of dampening for momentum. Default: 0.0.
+        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
         nesterov (bool): Enables the Nesterov momentum. Default: False.
         loss_scale (float): A floating point value for the loss scale, which should be larger
-            than 0.0. Default: 1.0.
+            than 0.0. Default: 1.0.

     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
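For reference, a minimal SGD construction spelling out the defaults documented above. Illustrative only; the `params` list stands in for a real network's trainable parameters and assumes the nn.SGD keyword names listed in the docstring.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter

params = [Parameter(Tensor(np.ones([3, 3]).astype(np.float32)), name="w")]
opt = nn.SGD(params, learning_rate=0.1, momentum=0.0, dampening=0.0,
             weight_decay=0.0, nesterov=False, loss_scale=1.0)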
@@ -0,0 +1,62 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test rmsprop """
+import numpy as np
+import pytest
+import mindspore.nn as nn
+from mindspore.common.api import _executor
+from mindspore import Tensor, Parameter
+from mindspore.nn import TrainOneStepCell, WithLossCell
+from mindspore.ops import operations as P
+from mindspore.nn.optim import RMSProp
+
+
+class Net(nn.Cell):
+    """ Net definition """
+    def __init__(self):
+        super(Net, self).__init__()
+        self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
+        self.bias = Parameter(Tensor(np.ones([10]).astype((np.float32))), name="bias")
+        self.matmul = P.MatMul()
+        self.biasAdd = P.BiasAdd()
+
+    def construct(self, x):
+        x = self.biasAdd(self.matmul(x, self.weight), self.bias)
+        return x
+
+
+def test_rmsprop_compile():
+    """ test_rmsprop_compile """
+    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
+    label = Tensor(np.zeros([1, 10]).astype(np.float32))
+    net = Net()
+    net.set_train()
+
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = RMSProp(net.trainable_params(), learning_rate=0.1)
+
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
+
+
+def test_rmsprop_e():
+    net = Net()
+    with pytest.raises(ValueError):
+        RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1)
+
+    with pytest.raises(TypeError):
+        RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1)