forked from mindspore-Ecosystem/mindspore
!239 Add dynamic learning rate decay and review optimizer code
Merge pull request !239 from fanglei/master
This commit is contained in:
commit
60958d6b25
|
@ -0,0 +1,300 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""dynamic learning rate"""
|
||||
import math
|
||||
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
|
||||
|
||||
def piecewise_constant_lr(milestone, learning_rates):
|
||||
r"""
|
||||
Get piecewise constant learning rate.
|
||||
|
||||
Calculate learning rate by given `milestone` and `learning_rates`. Let the value of `milestone` be
|
||||
:math:`(M_1, M_2, ..., M_N)` and the value of `learning_rates` be :math:`(x_1, x_2, ..., x_N)`. N is the length of
|
||||
`milestone`. Let the output learning rate be `y`.
|
||||
|
||||
.. math::
|
||||
y[i] = x_t for i \in [M_{t-1}, M_t)
|
||||
|
||||
Args:
|
||||
milestone (list[int]): A list of milestone. This list is a monotone increasing list.
|
||||
learning_rates (list[float]): A list of learning rates.
|
||||
|
||||
Returns:
|
||||
list[float]. The size of list is :math:`M_N`.
|
||||
|
||||
Examples:
|
||||
>>> milestone = [2, 5, 10]
|
||||
>>> learning_rates = [0.1, 0.05, 0.01]
|
||||
>>> lr = piecewise_constant_lr(milestone, learning_rates)
|
||||
[0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01]
|
||||
"""
|
||||
validator.check_type('milestone', milestone, (tuple, list))
|
||||
validator.check_type('learning_rates', learning_rates, (tuple, list))
|
||||
if len(milestone) != len(learning_rates):
|
||||
raise ValueError('The size of `milestone` must be same with the size of `learning_rates`.')
|
||||
|
||||
lr = []
|
||||
last_item = 0
|
||||
for i, item in enumerate(milestone):
|
||||
validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT)
|
||||
validator.check_type(f'learning_rates[{i}]', learning_rates[i], [float])
|
||||
if item < last_item:
|
||||
raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
|
||||
lr += [learning_rates[i]] * (item - last_item)
|
||||
last_item = item
|
||||
|
||||
return lr
|
||||
|
||||
|
||||
def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair):
|
||||
validator.check_integer('total_step', total_step, 0, Rel.GT)
|
||||
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
|
||||
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
|
||||
validator.check_float_positive('learning_rate', learning_rate)
|
||||
validator.check_float_positive('decay_rate', decay_rate)
|
||||
validator.check_type('is_stair', is_stair, [bool])
|
||||
|
||||
|
||||
def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
|
||||
r"""
|
||||
Calculate learning rate base on exponential decay function.
|
||||
|
||||
For the i-th step, the formula of computing decayed_learning_rate[i] is:
|
||||
|
||||
.. math::
|
||||
decayed\_learning\_rate[i] = learning\_rate * decay\_rate^{\frac{current\_epoch}{decay\_epoch}}
|
||||
|
||||
Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.
|
||||
|
||||
Args:
|
||||
learning_rate (float): The initial value of learning rate.
|
||||
decay_rate (float): The decay rate.
|
||||
total_step (int): The total number of steps.
|
||||
step_per_epoch (int): The number of steps in per epoch.
|
||||
decay_epoch (int): A value used to calculate decayed learning rate.
|
||||
is_stair (bool): If true, learning rate decay once every `decay_epoch` times. Default: False.
|
||||
|
||||
Returns:
|
||||
list[float]. The size of list is `total_step`.
|
||||
|
||||
Examples:
|
||||
>>> learning_rate = 0.1
|
||||
>>> decay_rate = 0.9
|
||||
>>> total_step = 6
|
||||
>>> step_per_epoch = 2
|
||||
>>> decay_epoch = 1
|
||||
>>> lr = exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
|
||||
[0.1, 0.1, 0.09000000000000001, 0.09000000000000001, 0.08100000000000002, 0.08100000000000002]
|
||||
"""
|
||||
_check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)
|
||||
|
||||
lr = []
|
||||
for i in range(total_step):
|
||||
if is_stair:
|
||||
lr.append(learning_rate * decay_rate ** math.floor(math.floor(i / step_per_epoch) / decay_epoch))
|
||||
else:
|
||||
lr.append(learning_rate * decay_rate ** (math.floor(i / step_per_epoch) / decay_epoch))
|
||||
return lr
|
||||
|
||||
|
||||
def natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
|
||||
r"""
|
||||
Calculate learning rate base on natural exponential decay function.
|
||||
|
||||
For the i-th step, the formula of computing decayed_learning_rate[i] is:
|
||||
|
||||
.. math::
|
||||
decayed\_learning\_rate[i] = learning\_rate * e^{-decay\_rate * current\_epoch}
|
||||
|
||||
Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.
|
||||
|
||||
Args:
|
||||
learning_rate (float): The initial value of learning rate.
|
||||
decay_rate (float): The decay rate.
|
||||
total_step (int): The total number of steps.
|
||||
step_per_epoch (int): The number of steps in per epoch.
|
||||
decay_epoch (int): A value used to calculate decayed learning rate.
|
||||
is_stair (bool): If true, learning rate decay once every `decay_epoch` times. Default: False.
|
||||
|
||||
Returns:
|
||||
list[float]. The size of list is `total_step`.
|
||||
|
||||
Examples:
|
||||
>>> learning_rate = 0.1
|
||||
>>> decay_rate = 0.9
|
||||
>>> total_step = 6
|
||||
>>> step_per_epoch = 2
|
||||
>>> decay_epoch = 2
|
||||
>>> lr = natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
|
||||
[0.1, 0.1, 0.1, 0.1, 0.016529888822158657, 0.016529888822158657]
|
||||
"""
|
||||
_check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)
|
||||
|
||||
function = lambda x, y: x
|
||||
if is_stair:
|
||||
function = lambda x, y: math.floor(x / y) * y
|
||||
|
||||
lr = []
|
||||
for i in range(total_step):
|
||||
lr.append(learning_rate * math.e ** (-decay_rate * function(math.floor(i / step_per_epoch), decay_epoch)))
|
||||
return lr
|
||||
|
||||
|
||||
def inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
|
||||
r"""
|
||||
Calculate learning rate base on inverse-time decay function.
|
||||
|
||||
For the i-th step, the formula of computing decayed_learning_rate[i] is:
|
||||
|
||||
.. math::
|
||||
decayed\_learning\_rate[i] = learning\_rate / (1 + decay\_rate * current\_epoch / decay\_epoch)
|
||||
|
||||
Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.
|
||||
|
||||
Args:
|
||||
learning_rate (float): The initial value of learning rate.
|
||||
decay_rate (float): The decay rate.
|
||||
total_step (int): The total number of steps.
|
||||
step_per_epoch (int): The number of steps in per epoch.
|
||||
decay_epoch (int): A value used to calculate decayed learning rate.
|
||||
is_stair (bool): If true, learning rate decay once every `decay_epoch` times. Default: False.
|
||||
|
||||
Returns:
|
||||
list[float]. The size of list is `total_step`.
|
||||
|
||||
Examples:
|
||||
>>> learning_rate = 0.1
|
||||
>>> decay_rate = 0.5
|
||||
>>> total_step = 6
|
||||
>>> step_per_epoch = 1
|
||||
>>> decay_epoch = 1
|
||||
>>> lr = inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
|
||||
[0.1, 0.06666666666666667, 0.05, 0.04, 0.03333333333333333, 0.028571428571428574]
|
||||
"""
|
||||
_check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)
|
||||
|
||||
lr = []
|
||||
for i in range(total_step):
|
||||
if is_stair:
|
||||
lr.append(learning_rate / (1 + decay_rate * math.floor(math.floor(i / step_per_epoch) / decay_epoch)))
|
||||
else:
|
||||
lr.append(learning_rate / (1 + decay_rate * math.floor(i / step_per_epoch) / decay_epoch))
|
||||
return lr
|
||||
|
||||
|
||||
def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
|
||||
r"""
|
||||
Calculate learning rate base on cosine decay function.
|
||||
|
||||
For the i-th step, the formula of computing decayed_learning_rate[i] is:
|
||||
|
||||
.. math::
|
||||
decayed\_learning\_rate[i] = min\_learning\_rate + 0.5 * (max\_learning\_rate - min\_learning\_rate) *
|
||||
(1 + cos(\frac{current\_epoch}{decay\_epoch}\pi))
|
||||
|
||||
Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.
|
||||
|
||||
Args:
|
||||
min_lr (float): The minimum value of learning rate.
|
||||
max_lr (float): The maximum value of learning rate.
|
||||
total_step (int): The total number of steps.
|
||||
step_per_epoch (int): The number of steps in per epoch.
|
||||
decay_epoch (int): A value used to calculate decayed learning rate.
|
||||
|
||||
Returns:
|
||||
list[float]. The size of list is `total_step`.
|
||||
|
||||
Examples:
|
||||
>>> min_lr = 0.01
|
||||
>>> max_lr = 0.1
|
||||
>>> total_step = 6
|
||||
>>> step_per_epoch = 2
|
||||
>>> decay_epoch = 2
|
||||
>>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
|
||||
[0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
|
||||
"""
|
||||
validator.check_float_positive('min_lr', min_lr)
|
||||
validator.check_float_positive('max_lr', max_lr)
|
||||
validator.check_integer('total_step', total_step, 0, Rel.GT)
|
||||
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
|
||||
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
|
||||
|
||||
delta = 0.5 * (max_lr - min_lr)
|
||||
lr = []
|
||||
for i in range(total_step):
|
||||
tmp_epoch = min(math.floor(i / step_per_epoch), decay_epoch)
|
||||
lr.append(min_lr + delta * (1 + math.cos(math.pi * tmp_epoch / decay_epoch)))
|
||||
return lr
|
||||
|
||||
|
||||
def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power,
|
||||
update_decay_epoch=False):
|
||||
r"""
|
||||
Calculate learning rate base on polynomial decay function.
|
||||
|
||||
For the i-th step, the formula of computing decayed_learning_rate[i] is:
|
||||
|
||||
.. math::
|
||||
decayed\_learning\_rate[i] = (learning\_rate - end\_learning\_rate) *
|
||||
(1 - tmp\_epoch / decay\_epoch)^{power} + end\_learning\_rate
|
||||
|
||||
Where :math:`tmp\_epoch=min(current\_epoch, decay\_epoch), current\_epoch=floor(\frac{i}{step\_per\_epoch})`.
|
||||
If `update_decay_epoch` is true, update the value of `decay_epoch` every epoch. The formula is
|
||||
:math:`decay\_epoch = decay\_epoch * ceil(current\_epoch / decay\_epoch)`
|
||||
|
||||
Args:
|
||||
learning_rate (float): The initial value of learning rate.
|
||||
end_learning_rate (float): The end value of learning rate.
|
||||
total_step (int): The total number of steps.
|
||||
step_per_epoch (int): The number of steps in per epoch.
|
||||
decay_epoch (int): A value used to calculate decayed learning rate.
|
||||
power (float): A value used to calculate decayed learning rate.
|
||||
update_decay_epoch (bool): If true, update `decay_epoch`. Default: False.
|
||||
|
||||
Returns:
|
||||
list[float]. The size of list is `total_step`.
|
||||
|
||||
Examples:
|
||||
>>> learning_rate = 0.1
|
||||
>>> end_learning_rate = 0.01
|
||||
>>> total_step = 6
|
||||
>>> step_per_epoch = 2
|
||||
>>> decay_epoch = 2
|
||||
>>> power = 0.5
|
||||
>>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
|
||||
[0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
|
||||
"""
|
||||
validator.check_float_positive('learning_rate', learning_rate)
|
||||
validator.check_float_positive('end_learning_rate', end_learning_rate)
|
||||
validator.check_integer('total_step', total_step, 0, Rel.GT)
|
||||
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
|
||||
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
|
||||
validator.check_type('power', power, [float])
|
||||
validator.check_type('update_decay_epoch', update_decay_epoch, [bool])
|
||||
|
||||
function = lambda x, y: (x, min(x, y))
|
||||
if update_decay_epoch:
|
||||
function = lambda x, y: (x * max(math.ceil(y / x), 1), y)
|
||||
|
||||
lr = []
|
||||
delta = learning_rate - end_learning_rate
|
||||
for i in range(total_step):
|
||||
current_epoch = math.floor(i / step_per_epoch)
|
||||
decay_epoch, tmp_epoch = function(decay_epoch, current_epoch)
|
||||
lr.append(delta * (1 - tmp_epoch / decay_epoch) ** power + end_learning_rate)
|
||||
return lr
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""adam"""
|
||||
from typing import Iterable
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common import dtype as mstype
|
||||
|
@ -25,7 +24,7 @@ from mindspore.common.parameter import Parameter
|
|||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from .optimizer import Optimizer, apply_decay, grad_scale
|
||||
from .optimizer import Optimizer
|
||||
|
||||
_learning_rate_update_func = ['linear', 'cos', 'sin']
|
||||
|
||||
|
@ -168,22 +167,13 @@ class Adam(Optimizer):
|
|||
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
|
||||
use_nesterov=False, weight_decay=0.0, loss_scale=1.0,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
|
||||
super(Adam, self).__init__(learning_rate, params)
|
||||
super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
|
||||
_check_param_value(beta1, beta2, eps, weight_decay)
|
||||
validator.check_type("use_locking", use_locking, [bool])
|
||||
validator.check_type("use_nesterov", use_nesterov, [bool])
|
||||
validator.check_type("loss_scale", loss_scale, [float])
|
||||
validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT)
|
||||
|
||||
self.dynamic_lr = False
|
||||
if isinstance(learning_rate, Iterable) or \
|
||||
(isinstance(learning_rate, Tensor) and learning_rate.dim() == 1):
|
||||
self.dynamic_lr = True
|
||||
self.gather = P.GatherV2()
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
|
||||
self.axis = 0
|
||||
|
||||
self.beta1 = Tensor(beta1, mstype.float32)
|
||||
self.beta2 = Tensor(beta2, mstype.float32)
|
||||
self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
|
||||
|
@ -196,8 +186,6 @@ class Adam(Optimizer):
|
|||
self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.Adam(use_locking, use_nesterov)
|
||||
self.weight_decay = weight_decay * loss_scale
|
||||
self.reciprocal_scale = 1.0 / loss_scale
|
||||
|
||||
self.pow = P.Pow()
|
||||
self.sqrt = P.Sqrt()
|
||||
|
@ -208,15 +196,9 @@ class Adam(Optimizer):
|
|||
params = self.parameters
|
||||
moment1 = self.moment1
|
||||
moment2 = self.moment2
|
||||
if self.weight_decay > 0:
|
||||
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
|
||||
if self.reciprocal_scale != 1.0:
|
||||
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
|
||||
|
||||
lr = self.learning_rate
|
||||
if self.dynamic_lr:
|
||||
lr = self.gather(self.learning_rate, self.global_step, self.axis)
|
||||
F.control_depend(lr, self.assignadd(self.global_step, self.one))
|
||||
gradients = self.decay_weight(gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
|
||||
beta1_power = self.beta1_power * self.beta1
|
||||
self.beta1_power = beta1_power
|
||||
|
|
|
@ -13,14 +13,9 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""momentum"""
|
||||
from typing import Iterable
|
||||
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common import Tensor
|
||||
from .optimizer import Optimizer, apply_decay, grad_scale
|
||||
from .optimizer import Optimizer
|
||||
|
||||
momentum_opt = C.MultitypeFuncGraph("momentum_opt")
|
||||
|
||||
|
@ -88,43 +83,20 @@ class Momentum(Optimizer):
|
|||
"""
|
||||
def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
|
||||
super(Momentum, self).__init__(learning_rate, params)
|
||||
super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
if isinstance(learning_rate, Iterable) or \
|
||||
(isinstance(learning_rate, Tensor) and learning_rate.dim() == 1):
|
||||
self.dynamic_lr = True
|
||||
self.gather = P.GatherV2()
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
|
||||
self.axis = 0
|
||||
else:
|
||||
self.dynamic_lr = False
|
||||
self.gather = None
|
||||
self.assignadd = None
|
||||
self.global_step = None
|
||||
self.axis = None
|
||||
self.momentum = Parameter(momentum, name="momentum")
|
||||
self.params = self.parameters
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.ApplyMomentum()
|
||||
self.weight_decay = weight_decay * loss_scale
|
||||
self.reciprocal_scale = 1.0 / loss_scale
|
||||
self.one = Tensor(1, mstype.int32)
|
||||
|
||||
def construct(self, gradients):
|
||||
params = self.params
|
||||
moments = self.moments
|
||||
if self.weight_decay > 0:
|
||||
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
|
||||
if self.reciprocal_scale != 1.0:
|
||||
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
|
||||
if self.dynamic_lr:
|
||||
lr = self.gather(self.learning_rate, self.global_step, self.axis)
|
||||
F.control_depend(lr, self.assignadd(self.global_step, self.one))
|
||||
else:
|
||||
lr = self.learning_rate
|
||||
gradients = self.decay_weight(gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
|
||||
return success
|
||||
|
|
|
@ -17,9 +17,11 @@ from typing import Iterable
|
|||
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
@ -42,34 +44,110 @@ class Optimizer(Cell):
|
|||
Args:
|
||||
learning_rate (float): A floating point value for the learning rate. Should be greater than 0.
|
||||
parameters (list): A list of parameter, which will be updated. The element in `parameters`
|
||||
should be class mindspore.Parameter.
|
||||
should be class mindspore.Parameter.
|
||||
weight_decay (float): A floating point value for the weight decay. Default: 0.0.
|
||||
loss_scale (float): A floating point value for the loss scale. Default: 1.0. Should be greater than 0.
|
||||
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda
|
||||
x: 'beta' not in x.name and 'gamma' not in x.name.
|
||||
|
||||
Raises:
|
||||
ValueError: If the learning_rate is a Tensor, but the dims of tensor is greater than 1.
|
||||
TypeError: If the learning_rate is not any of the three types: float, Tensor, Iterable.
|
||||
"""
|
||||
|
||||
def __init__(self, learning_rate, parameters):
|
||||
def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
|
||||
super(Optimizer, self).__init__()
|
||||
if isinstance(learning_rate, float):
|
||||
self.dynamic_lr = False
|
||||
self.gather = None
|
||||
self.assignadd = None
|
||||
self.global_step = None
|
||||
validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT)
|
||||
elif isinstance(learning_rate, Iterable):
|
||||
learning_rate = Tensor(np.array(list(learning_rate)).astype(np.float32))
|
||||
elif isinstance(learning_rate, Tensor):
|
||||
if learning_rate.dim() > 1:
|
||||
raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`,"
|
||||
f"but got {learning_rate.dim()}.")
|
||||
else:
|
||||
raise TypeError("Learning rate should be float, Tensor or Iterable.")
|
||||
self.dynamic_lr = True
|
||||
self.gather = P.GatherV2()
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')
|
||||
if isinstance(learning_rate, Iterable):
|
||||
learning_rate = Tensor(np.array(list(learning_rate)).astype(np.float32))
|
||||
elif isinstance(learning_rate, Tensor):
|
||||
if learning_rate.dim() > 1:
|
||||
raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`,"
|
||||
f"but got {learning_rate.dim()}.")
|
||||
if learning_rate.dim() == 1 and learning_rate.size() < 2:
|
||||
logger.warning("If want to use the dynamic learning rate, please make sure that the number "
|
||||
"of elements in the list, tuple or tensor passed is greater than 1.")
|
||||
else:
|
||||
raise TypeError("Learning rate should be float, Tensor or Iterable.")
|
||||
|
||||
if loss_scale <= 0.0:
|
||||
raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale))
|
||||
if weight_decay < 0.0:
|
||||
raise ValueError("Weight decay should be equal or greater than 0, but got {}".format(weight_decay))
|
||||
|
||||
if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1 and learning_rate.size() < 2:
|
||||
logger.warning("If want to use the dynamic learning rate, please make sure that "
|
||||
"the number of elements in the list, tuple or tensor passed is greater than 1.")
|
||||
self.learning_rate = Parameter(learning_rate, name="learning_rate")
|
||||
self.parameters = ParameterTuple(parameters)
|
||||
self.reciprocal_scale = 1.0 / loss_scale
|
||||
self.weight_decay = weight_decay * loss_scale
|
||||
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
|
||||
|
||||
if not self.parameters:
|
||||
raise ValueError("optimizer got an empty parameter list.")
|
||||
|
||||
def decay_weight(self, gradients):
|
||||
"""
|
||||
Weight decay.
|
||||
|
||||
An approach to reduce the overfitting of a deep learning neural network model.
|
||||
|
||||
Args:
|
||||
gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape with
|
||||
`self.parameters`.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor], The gradients after weight decay.
|
||||
"""
|
||||
if self.weight_decay > 0:
|
||||
params = self.params
|
||||
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
|
||||
|
||||
return gradients
|
||||
|
||||
def scale_grad(self, gradients):
|
||||
"""
|
||||
Loss scale for mixed precision.
|
||||
|
||||
An approach of mixed precision training to improve the speed and energy efficiency of training deep neural
|
||||
network.
|
||||
|
||||
Args:
|
||||
gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape with
|
||||
`self.parameters`.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor], The gradients after loss scale.
|
||||
|
||||
"""
|
||||
if self.reciprocal_scale != 1.0:
|
||||
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
|
||||
|
||||
return gradients
|
||||
|
||||
def get_lr(self):
|
||||
"""
|
||||
Get the learning rate of current step.
|
||||
|
||||
Returns:
|
||||
float, the learning rate of current step.
|
||||
"""
|
||||
lr = self.learning_rate
|
||||
if self.dynamic_lr:
|
||||
lr = self.gather(self.learning_rate, self.global_step, 0)
|
||||
F.control_depend(lr, self.assignadd(self.global_step, 1))
|
||||
|
||||
return lr
|
||||
|
||||
def construct(self, *hyper_params):
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -14,12 +14,8 @@
|
|||
# ============================================================================
|
||||
"""rmsprop"""
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common import Tensor
|
||||
from .optimizer import Optimizer, grad_scale, apply_decay
|
||||
from .optimizer import Optimizer
|
||||
|
||||
rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
|
||||
centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
|
||||
|
@ -138,7 +134,7 @@ class RMSProp(Optimizer):
|
|||
def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
|
||||
use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
|
||||
super(RMSProp, self).__init__(learning_rate, params)
|
||||
super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
|
||||
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
|
@ -157,15 +153,6 @@ class RMSProp(Optimizer):
|
|||
else:
|
||||
self.opt = P.ApplyRMSProp(use_locking)
|
||||
|
||||
self.dynamic_lr = False
|
||||
if not isinstance(learning_rate, float):
|
||||
self.dynamic_lr = True
|
||||
self.gather = P.GatherV2()
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
|
||||
self.axis = 0
|
||||
self.one = Tensor(1, mstype.int32)
|
||||
|
||||
self.momentum = momentum
|
||||
|
||||
self.ms = self.parameters.clone(prefix="mean_square", init='zeros')
|
||||
|
@ -173,21 +160,12 @@ class RMSProp(Optimizer):
|
|||
self.hyper_map = C.HyperMap()
|
||||
|
||||
self.decay = decay
|
||||
self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
|
||||
self.reciprocal_scale = 1.0 / loss_scale
|
||||
self.weight_decay = weight_decay * loss_scale
|
||||
|
||||
def construct(self, gradients):
|
||||
params = self.parameters
|
||||
if self.weight_decay > 0:
|
||||
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
|
||||
if self.reciprocal_scale != 1.0:
|
||||
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
|
||||
if self.dynamic_lr:
|
||||
lr = self.gather(self.learning_rate, self.global_step, self.axis)
|
||||
F.control_depend(lr, self.assignadd(self.global_step, self.one))
|
||||
else:
|
||||
lr = self.learning_rate
|
||||
gradients = self.decay_weight(gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
if self.centered:
|
||||
success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
|
||||
self.momentum), params, self.mg, self.ms, self.moment, gradients)
|
||||
|
|
|
@ -14,11 +14,9 @@
|
|||
# ============================================================================
|
||||
"""sgd"""
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
import mindspore.common.dtype as mstype
|
||||
from .optimizer import Optimizer, grad_scale
|
||||
from .optimizer import Optimizer
|
||||
|
||||
sgd_opt = C.MultitypeFuncGraph("sgd_opt")
|
||||
|
||||
|
@ -83,7 +81,7 @@ class SGD(Optimizer):
|
|||
def __init__(self, params, learning_rate=0.1, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False,
|
||||
loss_scale=1.0):
|
||||
|
||||
super(SGD, self).__init__(learning_rate, params)
|
||||
super(SGD, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
|
@ -92,44 +90,22 @@ class SGD(Optimizer):
|
|||
raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
|
||||
self.dampening = dampening
|
||||
|
||||
if weight_decay < 0.0:
|
||||
raise ValueError("weight_decay should be at least 0.0, but got weight_decay {}".format(weight_decay))
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
validator.check_type("nesterov", nesterov, [bool])
|
||||
self.nesterov = nesterov
|
||||
|
||||
self.opt = P.SGD(dampening, weight_decay, nesterov)
|
||||
|
||||
self.dynamic_lr = False
|
||||
self.gather = None
|
||||
self.global_step = None
|
||||
self.axis = None
|
||||
if not isinstance(learning_rate, float):
|
||||
self.dynamic_lr = True
|
||||
self.gather = P.GatherV2()
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
|
||||
self.axis = 0
|
||||
self.momentum = Parameter(momentum, name="momentum")
|
||||
self.params = self.parameters
|
||||
self.accum = self.params.clone(prefix="accum", init='zeros')
|
||||
self.stat = self.params.clone(prefix="stat", init='ones')
|
||||
self.accum = self.parameters.clone(prefix="accum", init='zeros')
|
||||
self.stat = self.parameters.clone(prefix="stat", init='ones')
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
self.weight_decay = weight_decay * loss_scale
|
||||
self.reciprocal_scale = 1.0 / loss_scale
|
||||
|
||||
def construct(self, gradients):
|
||||
params = self.params
|
||||
params = self.parameters
|
||||
accum = self.accum
|
||||
stat = self.stat
|
||||
if self.reciprocal_scale != 1.0:
|
||||
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
|
||||
if self.dynamic_lr:
|
||||
lr = self.gather(self.learning_rate, self.global_step, self.axis)
|
||||
F.control_depend(lr, self.assignadd(self.global_step, 1))
|
||||
else:
|
||||
lr = self.learning_rate
|
||||
gradients = self.decay_weight(gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
success = self.hyper_map(F.partial(sgd_opt, self.opt, lr, self.momentum), gradients, params, accum, stat)
|
||||
return success
|
||||
|
|
|
@ -15,17 +15,11 @@
|
|||
""" test optimizer """
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
|
||||
gradient = Tensor(np.zeros([1, 2, 3]))
|
||||
accumulation = gradient
|
||||
variable = accumulation
|
||||
|
||||
|
||||
paramsTensor = Tensor(np.zeros([1, 2, 3]))
|
||||
class IterableObjc:
|
||||
def __iter__(self):
|
||||
cont = 0
|
||||
|
@ -56,6 +50,7 @@ class TestAdam():
|
|||
|
||||
def test_construct(self):
|
||||
with pytest.raises(TypeError):
|
||||
gradient = Tensor(np.zeros([1, 2, 3]))
|
||||
adam = Adam(params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
|
||||
use_nesterov=False, weight_decay=0.0, loss_scale=1.0)
|
||||
adam.construct(gradient)
|
||||
|
@ -105,4 +100,5 @@ class TestUnsupportParam():
|
|||
|
||||
def test_Sgd_init(self):
|
||||
with pytest.raises(TypeError):
|
||||
paramsTensor = Tensor(np.zeros([1, 2, 3]))
|
||||
SGD(paramsTensor)
|
||||
|
|
|
@ -0,0 +1,234 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" Test Dynamic Learning Rate """
|
||||
import pytest
|
||||
import mindspore
|
||||
from mindspore.nn import dynamic_lr as dr
|
||||
|
||||
milestone = [10, 20, 30]
|
||||
learning_rates = [0.1, 0.05, 0.01]
|
||||
learning_rate = 0.1
|
||||
end_learning_rate = 0.01
|
||||
decay_rate = 0.9
|
||||
total_step = 30
|
||||
step_per_epoch = 3
|
||||
decay_epoch = 2
|
||||
min_lr = 0.01
|
||||
max_lr = 0.1
|
||||
power = 0.5
|
||||
|
||||
class TestInputs:
|
||||
def test_milestone1(self):
|
||||
milestone1 = 1
|
||||
with pytest.raises(ValueError):
|
||||
dr.piecewise_constant_lr(milestone1, learning_rates)
|
||||
|
||||
def test_milestone2(self):
|
||||
milestone1 = [20, 10, 1]
|
||||
with pytest.raises(ValueError):
|
||||
dr.piecewise_constant_lr(milestone1, learning_rates)
|
||||
|
||||
milestone2 = [1.0, 2.0, True]
|
||||
with pytest.raises(ValueError):
|
||||
dr.piecewise_constant_lr(milestone2, learning_rates)
|
||||
|
||||
def test_learning_rates1(self):
|
||||
lr = True
|
||||
with pytest.raises(ValueError):
|
||||
dr.piecewise_constant_lr(milestone, lr)
|
||||
|
||||
def test_learning_rates2(self):
|
||||
lr = [1, 2, 1]
|
||||
with pytest.raises(ValueError):
|
||||
dr.piecewise_constant_lr(milestone, lr)
|
||||
|
||||
def test_learning_rate_type(self):
|
||||
lr = True
|
||||
with pytest.raises(TypeError):
|
||||
dr.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
dr.polynomial_decay_lr(lr, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
|
||||
|
||||
def test_learning_rate_value(self):
|
||||
lr = -1.0
|
||||
with pytest.raises(ValueError):
|
||||
dr.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
dr.polynomial_decay_lr(lr, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
|
||||
|
||||
def test_end_learning_rate_type(self):
|
||||
lr = True
|
||||
with pytest.raises(TypeError):
|
||||
dr.polynomial_decay_lr(learning_rate, lr, total_step, step_per_epoch, decay_epoch, power)
|
||||
|
||||
def test_end_learning_rate_value(self):
|
||||
lr = -1.0
|
||||
with pytest.raises(ValueError):
|
||||
dr.polynomial_decay_lr(learning_rate, lr, total_step, step_per_epoch, decay_epoch, power)
|
||||
|
||||
def test_decay_rate_type(self):
|
||||
rate = 'a'
|
||||
with pytest.raises(TypeError):
|
||||
dr.exponential_decay_lr(learning_rate, rate, total_step, step_per_epoch, decay_epoch)
|
||||
|
||||
def test_decay_rate_value(self):
|
||||
rate = -1.0
|
||||
with pytest.raises(ValueError):
|
||||
dr.exponential_decay_lr(learning_rate, rate, total_step, step_per_epoch, decay_epoch)
|
||||
|
||||
def test_total_step1(self):
|
||||
total_step1 = 2.0
|
||||
with pytest.raises(ValueError):
|
||||
dr.exponential_decay_lr(learning_rate, decay_rate, total_step1, step_per_epoch, decay_epoch)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
dr.cosine_decay_lr(min_lr, max_lr, total_step1, step_per_epoch, decay_epoch)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step1, step_per_epoch, decay_epoch, power)
|
||||
|
||||
def test_total_step2(self):
|
||||
total_step1 = -1
|
||||
with pytest.raises(ValueError):
|
||||
dr.exponential_decay_lr(learning_rate, decay_rate, total_step1, step_per_epoch, decay_epoch)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
dr.cosine_decay_lr(min_lr, max_lr, total_step1, step_per_epoch, decay_epoch)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step1, step_per_epoch, decay_epoch, power)
|
||||
|
||||
def test_step_per_epoch1(self):
|
||||
step_per_epoch1 = True
|
||||
with pytest.raises(ValueError):
|
||||
dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch1, decay_epoch)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch1, decay_epoch)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch1, decay_epoch, power)
|
||||
|
||||
def test_step_per_epoch2(self):
|
||||
step_per_epoch1 = -1
|
||||
with pytest.raises(ValueError):
|
||||
dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch1, decay_epoch)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch1, decay_epoch)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch1, decay_epoch, power)
|
||||
|
||||
def test_decay_epoch1(self):
|
||||
decay_epoch1 = 'm'
|
||||
with pytest.raises(ValueError):
|
||||
dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch1)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch1)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch1, power)
|
||||
|
||||
def test_decay_epoch2(self):
|
||||
decay_epoch1 = -1
|
||||
with pytest.raises(ValueError):
|
||||
dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch1)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch1)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch1, power)
|
||||
|
||||
def test_is_stair(self):
|
||||
is_stair = 1
|
||||
with pytest.raises(ValueError):
|
||||
dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)
|
||||
|
||||
def test_min_lr_type(self):
|
||||
min_lr1 = True
|
||||
with pytest.raises(TypeError):
|
||||
dr.cosine_decay_lr(min_lr1, max_lr, total_step, step_per_epoch, decay_epoch)
|
||||
|
||||
def test_min_lr_value(self):
|
||||
min_lr1 = -1.0
|
||||
with pytest.raises(ValueError):
|
||||
dr.cosine_decay_lr(min_lr1, max_lr, total_step, step_per_epoch, decay_epoch)
|
||||
|
||||
def test_max_lr_type(self):
|
||||
max_lr1 = 'a'
|
||||
with pytest.raises(TypeError):
|
||||
dr.cosine_decay_lr(min_lr, max_lr1, total_step, step_per_epoch, decay_epoch)
|
||||
|
||||
def test_max_lr_value(self):
|
||||
max_lr1 = -1.0
|
||||
with pytest.raises(ValueError):
|
||||
dr.cosine_decay_lr(min_lr, max_lr1, total_step, step_per_epoch, decay_epoch)
|
||||
|
||||
def test_power(self):
|
||||
power1 = True
|
||||
with pytest.raises(ValueError):
|
||||
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power1)
|
||||
|
||||
def test_update_decay_epoch(self):
|
||||
update_decay_epoch = 1
|
||||
with pytest.raises(ValueError):
|
||||
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch,
|
||||
power, update_decay_epoch)
|
||||
|
||||
|
||||
def test_learning_rate():
|
||||
lr = dr.piecewise_constant_lr(milestone, learning_rates)
|
||||
assert len(lr) == milestone[-1]
|
||||
|
||||
|
||||
def test_exponential_decay():
|
||||
lr1 = dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
|
||||
assert len(lr1) == total_step
|
||||
|
||||
lr2 = dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
|
||||
assert len(lr2) == total_step
|
||||
|
||||
|
||||
def test_enatural_exp_decay():
|
||||
lr1 = dr.natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
|
||||
assert len(lr1) == total_step
|
||||
|
||||
lr2 = dr.natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
|
||||
assert len(lr2) == total_step
|
||||
|
||||
|
||||
def test_inverse_decay():
|
||||
lr1 = dr.inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
|
||||
assert len(lr1) == total_step
|
||||
|
||||
lr2 = dr.inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
|
||||
assert len(lr2) == total_step
|
||||
|
||||
|
||||
def test_cosine_decay():
|
||||
lr = dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
|
||||
assert len(lr) == total_step
|
||||
|
||||
def test_polynomial_decay():
|
||||
lr1 = dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
|
||||
assert len(lr1) == total_step
|
||||
lr2 = dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power,
|
||||
True)
|
||||
assert len(lr2) == total_step
|
Loading…
Reference in New Issue