forked from mindspore-Ecosystem/mindspore
!25557 Add ASGD and Rprop optimizers to MindSpore
Merge pull request !25557 from wanyiming/optimizer
This commit is contained in: commit 15cc4673a0
@@ -23,6 +23,8 @@ from .momentum import Momentum
from .adam import Adam, AdamWeightDecay, AdamOffload
|
||||
from .lamb import Lamb
|
||||
from .sgd import SGD
|
||||
from .asgd import ASGD
|
||||
from .rprop import Rprop
|
||||
from .lars import LARS
|
||||
from .ftrl import FTRL
|
||||
from .rmsprop import RMSProp
@@ -33,4 +35,4 @@ from .thor import thor
from .adafactor import AdaFactor
|
||||
|
||||
__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam', 'AdamOffload',
|
||||
'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad', 'Adagrad', 'thor', 'AdaFactor']
|
||||
'Lamb', 'SGD', 'ASGD', 'Rprop', 'FTRL', 'RMSProp', 'ProximalAdagrad', 'Adagrad', 'thor', 'AdaFactor']
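With these new entries exported, both optimizers can be constructed through `mindspore.nn` just like the existing ones; a minimal sketch, assuming a small `Dense` layer only as a stand-in network:

import mindspore.nn as nn

# Stand-in network, only to obtain some trainable parameters.
net = nn.Dense(4, 2)

# After this change both optimizers are exposed through mindspore.nn.
asgd = nn.ASGD(net.trainable_params(), learning_rate=0.1)
rprop = nn.Rprop(net.trainable_params(), learning_rate=0.1)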
@@ -0,0 +1,190 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""asgd"""
|
||||
from mindspore.ops import functional as F, operations as P
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from .optimizer import Optimizer
|
||||
from .optimizer import opt_init_args_register
|
||||
|
||||
class ASGD(Optimizer):
|
||||
r"""
|
||||
Implements Average Stochastic Gradient Descent.
|
||||
|
||||
Introduction to ASGD can be found at `Acceleration of stochastic approximation by averaging
<http://dl.acm.org/citation.cfm?id=131098>`_.
|
||||
|
||||
The updating formulas are as follows:
.. math::
\begin{gather*}
w_{t} = w_{t-1} * (1 - \lambda * \eta_{t-1}) - \eta_{t-1} * g_{t} \\
ax_{t} = ax_{t-1} + (w_{t} - ax_{t-1}) * \mu_{t-1} \\
\eta_{t} = \frac{lr}{(1 + \lambda * lr * t)^\alpha} \\
\mu_{t} = \frac{1}{\max(1, t - t0)}
\end{gather*}
|
||||
|
||||
:math:`\lambda` represents the decay term, :math:`\mu` and :math:`\eta` are tracked to
|
||||
update :math:`ax` and :math:`w`, :math:`t0` represents the point of starting averaging,
|
||||
:math:`\alpha` represents the power for eta update, :math:`ax` represents the averaged
|
||||
parameter value, :math:`t` represents the current step, :math:`g` represents `gradients`,
|
||||
:math:`w` represents `params`.
|
||||
|
||||
Note:
|
||||
If parameters are not grouped, the `weight_decay` in optimizer will be applied on the parameters without 'beta'
|
||||
or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight. When parameters
|
||||
are grouped, each group can set `weight_decay`, if not, the `weight_decay` in optimizer will be applied.
|
||||
|
||||
Args:
|
||||
params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
|
||||
`parameters` is a list of `dict`, the "params", "lr", "weight_decay", "grad_centralization" and
|
||||
"order_params" are the keys can be parsed.
|
||||
|
||||
- params: Required. Parameters in current group. The value must be a list of `Parameter`.
|
||||
|
||||
- lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
|
||||
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used.
|
||||
|
||||
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
|
||||
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
|
||||
convolution layer.
|
||||
|
||||
- order_params: Optional. When parameters is grouped, this usually is used to maintain the order of
|
||||
parameters that appeared in the network to improve performance. The value should be parameters whose
|
||||
order will be followed in optimizer.
|
||||
If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
|
||||
one group of `params`.
|
||||
|
||||
learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]):
|
||||
|
||||
- float: The fixed learning rate value. Must be equal to or greater than 0.
|
||||
|
||||
- int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float.
|
||||
|
||||
- Tensor: Its value should be a scalar or a 1-D vector. For scalar, fixed learning rate will be applied.
|
||||
For vector, learning rate is dynamic, then the i-th step will take the i-th value as the learning rate.
|
||||
|
||||
- Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate.
|
||||
|
||||
- LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of
|
||||
LearningRateSchedule with step as the input to get the learning rate of current step.
|
||||
|
||||
lambd (float): The decay term. Default: 1e-4.
|
||||
alpha (float): The power for eta update. Default: 0.75.
|
||||
t0 (float): The point of starting averaging. Default: 1e6.
|
||||
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
||||
Outputs:
|
||||
Tensor[bool], the value is True.
|
||||
|
||||
Raises:
|
||||
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
|
||||
TypeError: If element of `parameters` is neither Parameter nor dict.
|
||||
TypeError: If `lambd`, `alpha` or `t0` is not a float.
|
||||
TypeError: If `weight_decay` is neither float nor int.
|
||||
ValueError: If `weight_decay` is less than 0.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> #1) All parameters use the same learning rate and weight decay
|
||||
>>> optim = nn.ASGD(params=net.trainable_params())
|
||||
>>>
|
||||
>>> #2) Use parameter groups and set different values
|
||||
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
|
||||
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
|
||||
>>> group_params = [{'params': conv_params,'grad_centralization':True},
|
||||
... {'params': no_conv_params, 'lr': 0.01},
|
||||
... {'order_params': net.trainable_params()}]
|
||||
>>> optim = nn.ASGD(group_params, learning_rate=0.1, weight_decay=0.0)
|
||||
>>> # The conv_params's parameters will use the default learning rate of 0.1, the default weight decay
>>> # of 0.0 and grad centralization of True.
>>> # The no_conv_params's parameters will use a learning rate of 0.01, the default weight decay of 0.0
>>> # and grad centralization of False.
>>> # The order in which the optimizer updates the parameters follows the value of 'order_params'.
|
||||
>>>
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim)
|
||||
"""
|
||||
|
||||
@opt_init_args_register
|
||||
def __init__(self, params, learning_rate=0.1, lambd=1e-4, alpha=0.75, t0=1e6, weight_decay=0.):
|
||||
|
||||
super(ASGD, self).__init__(learning_rate, params, weight_decay)
|
||||
|
||||
validator.check_value_type("lambd", lambd, [float], self.cls_name)
|
||||
validator.check_value_type("alpha", alpha, [float], self.cls_name)
|
||||
validator.check_value_type("t0", t0, [float], self.cls_name)
|
||||
|
||||
self.lambd = lambd
|
||||
self.alpha = alpha
|
||||
self.t0 = Tensor([t0], dtype=mstype.float32)
|
||||
mu, eta = [], []
|
||||
for param in self.parameters:
|
||||
mu.append(Parameter(Tensor(1., dtype=mstype.float32), name='mu_'+param.name))
|
||||
eta.append(Parameter(Tensor(0., dtype=mstype.float32), name='eta_'+param.name))
|
||||
self.lens = len(self.parameters)
|
||||
self.mu = mindspore.ParameterTuple(mu)
|
||||
self.eta = mindspore.ParameterTuple(eta)
|
||||
self.step = Parameter(Tensor(1., dtype=mstype.float32), name='step')
|
||||
self.ax = self.parameters.clone(prefix="ax_", init='zeros')
|
||||
self.pow = P.Pow()
|
||||
self.maximum = P.Maximum()
|
||||
self.assign = P.Assign()
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.assignsub = P.AssignSub()
|
||||
self.cast = P.Cast()
|
||||
self.squeeze = P.Squeeze()
|
||||
|
||||
def construct(self, gradients):
|
||||
gradients = self.decay_weight(gradients)
|
||||
gradients = self.gradients_centralization(gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lrs = self.get_lr()
|
||||
success = True
|
||||
|
||||
for index, (grad, param, mu, eta, ax) in enumerate(zip(gradients, self.parameters, self.mu, self.eta, self.ax)):
|
||||
lr = lrs[index] if self.is_group_lr else lrs
|
||||
|
||||
if self.step == 1.:
|
||||
self.assign(eta, lr)
|
||||
|
||||
param_fp32 = self.cast(param, mstype.float32)
|
||||
gradient_fp32 = self.cast(grad, mstype.float32)
|
||||
ax_fp32 = self.cast(ax, mstype.float32)
|
||||
param_fp32 = param_fp32 * (1. - self.lambd * eta) - eta * gradient_fp32
|
||||
|
||||
self.assign(param, self.cast(param_fp32, param.dtype))
|
||||
|
||||
if mu != 1:
|
||||
self.assignadd(ax, self.cast((param_fp32 - ax_fp32) * mu, ax.dtype))
|
||||
else:
|
||||
self.assign(ax, param)
|
||||
|
||||
self.assign(eta, lr / (self.pow((1. + (self.lambd * lr * self.step)), self.alpha)))
|
||||
self.assign(mu, 1. / self.squeeze(self.maximum(1., self.step - self.t0)))
|
||||
|
||||
success = F.depend(success, self.assignadd(self.step, 1.))
|
||||
return success
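The update performed above for each parameter can be written down compactly; a minimal NumPy sketch of one ASGD step, assuming a single parameter array and scalar hyperparameters (the helper name is only illustrative, not MindSpore code):

import numpy as np

def asgd_step(w, ax, grad, eta, mu, step, lr=0.1, lambd=1e-4, alpha=0.75, t0=1e6):
    """One ASGD update mirroring construct() above (plain NumPy sketch)."""
    if step == 1:
        eta = lr                                    # eta starts from the learning rate
    w = w * (1. - lambd * eta) - eta * grad         # decayed SGD step
    if mu != 1:
        ax = ax + (w - ax) * mu                     # running average of the parameter
    else:
        ax = w
    eta = lr / (1. + lambd * lr * step) ** alpha    # shrink the step size over time
    mu = 1. / max(1., step - t0)                    # true averaging kicks in after t0 steps
    return w, ax, eta, mu

# Tiny usage example on a 2-element parameter.
w, ax, eta, mu = np.ones(2), np.zeros(2), 0.0, 1.0
for t in range(1, 4):
    w, ax, eta, mu = asgd_step(w, ax, grad=0.5 * w, eta=eta, mu=mu, step=t)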
@@ -0,0 +1,215 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""rprop"""
|
||||
from mindspore import ops
|
||||
from mindspore.ops import functional as F, operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from .optimizer import Optimizer
|
||||
from .optimizer import opt_init_args_register
|
||||
|
||||
class Rprop(Optimizer):
|
||||
r"""
|
||||
Implements Resilient backpropagation.
|
||||
|
||||
Further information about this implementation can be found at `A Direct Adaptive Method for Faster Backpropagation
|
||||
Learning: The RPROP Algorithm <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417>`_.
|
||||
|
||||
The updating formulas are as follows:
|
||||
|
||||
.. math::
|
||||
\begin{gather*}
|
||||
&\hspace{0mm} \textbf{if} \: g_{t-1} g_t > 0 \\
|
||||
&\hspace{5mm} \Delta_t \leftarrow \mathrm{min}(\Delta_{t-1} \eta_{+}, \Delta_{max}) \\
|
||||
&\hspace{0mm} \textbf{else if} \: g_{t-1} g_t < 0 \\
|
||||
&\hspace{5mm} \Delta_t \leftarrow \mathrm{max}(\Delta_{t-1} \eta_{-}, \Delta_{min}) \\
|
||||
&\hspace{0mm} \textbf{else} \: \\
|
||||
&\hspace{5mm} \Delta_t \leftarrow \Delta_{t-1} \\
|
||||
&\hspace{0mm} w_{t} \leftarrow w_{t-1}- \Delta_{t} \mathrm{sign}(g_t) \\
|
||||
\end{gather*}
|
||||
|
||||
:math:`\Delta_{min/max}` represents the min/max step size, :math:`\eta_{+/-}` represents the factors of
|
||||
etaminus and etaplus, :math:`g` represents `gradients`, :math:`w` represents `parameters`.
|
||||
|
||||
Note:
|
||||
If parameters are not grouped, the `weight_decay` in optimizer will be applied on the parameters without 'beta'
|
||||
or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight. When parameters
|
||||
are grouped, each group can set `weight_decay`, if not, the `weight_decay` in optimizer will be applied.
|
||||
|
||||
Args:
|
||||
params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
|
||||
`parameters` is a list of `dict`, the "params", "lr", "weight_decay", "grad_centralization" and
|
||||
"order_params" are the keys can be parsed.
|
||||
|
||||
- params: Required. Parameters in current group. The value must be a list of `Parameter`.
|
||||
|
||||
- lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
|
||||
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used.
|
||||
|
||||
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
|
||||
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
|
||||
convolution layer.
|
||||
|
||||
- order_params: Optional. When parameters is grouped, this usually is used to maintain the order of
|
||||
parameters that appeared in the network to improve performance. The value should be parameters whose
|
||||
order will be followed in optimizer.
|
||||
If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
|
||||
one group of `params`.
|
||||
|
||||
learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]):
|
||||
|
||||
- float: The fixed learning rate value. Must be equal to or greater than 0.
|
||||
|
||||
- int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float.
|
||||
|
||||
- Tensor: Its value should be a scalar or a 1-D vector. For scalar, fixed learning rate will be applied.
|
||||
For vector, learning rate is dynamic, then the i-th step will take the i-th value as the learning rate.
|
||||
|
||||
- Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate.
|
||||
|
||||
- LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of
|
||||
LearningRateSchedule with step as the input to get the learning rate of current step.
|
||||
|
||||
etas (tuple[float, float]): The multiplicative decrease and increase factors
(etaminus, etaplus). Default: (0.5, 1.2).
step_sizes (tuple[float, float]): The allowed minimal and maximal step size (min_step_size, max_step_size).
Default: (1e-6, 50.).
|
||||
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
||||
Outputs:
|
||||
Tensor[bool], the value is True.
|
||||
|
||||
Raises:
|
||||
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
|
||||
TypeError: If element of `parameters` is neither Parameter nor dict.
|
||||
TypeError: If `step_sizes` or `etas` is not a tuple.
|
||||
ValueError: If maximal step size is less than minimal step size.
|
||||
ValueError: If the length of `step_sizes` or `etas` is not equal to 2.
|
||||
TypeError: If the element in `etas` or `step_sizes` is not a float.
|
||||
ValueError: If `etaminus` is not in the range of (0, 1) or `etaplus` is not greater than 1.
|
||||
TypeError: If `weight_decay` is neither float nor int.
|
||||
ValueError: If `weight_decay` is less than 0.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> #1) All parameters use the same learning rate and weight decay
|
||||
>>> optim = nn.Rprop(params=net.trainable_params())
|
||||
>>>
|
||||
>>> #2) Use parameter groups and set different values
|
||||
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
|
||||
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
|
||||
>>> group_params = [{'params': conv_params,'grad_centralization':True},
|
||||
... {'params': no_conv_params, 'lr': 0.01},
|
||||
... {'order_params': net.trainable_params()}]
|
||||
>>> optim = nn.Rprop(group_params, learning_rate=0.1, weight_decay=0.0)
|
||||
>>> # The conv_params's parameters will use the default learning rate of 0.1, the default weight decay
>>> # of 0.0 and grad centralization of True.
>>> # The no_conv_params's parameters will use a learning rate of 0.01, the default weight decay of 0.0
>>> # and grad centralization of False.
>>> # The order in which the optimizer updates the parameters follows the value of 'order_params'.
|
||||
>>>
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim)
|
||||
"""
|
||||
|
||||
@opt_init_args_register
|
||||
def __init__(self, params, learning_rate=0.1, etas=(0.5, 1.2), step_sizes=(1e-6, 50.), weight_decay=0.0):
|
||||
|
||||
super(Rprop, self).__init__(learning_rate, params, weight_decay)
|
||||
if not isinstance(etas, tuple):
|
||||
raise TypeError("For Rprop, etas should be a tuple, but got {}.".format(type(etas)))
|
||||
if len(etas) != 2:
|
||||
raise ValueError("For Rprop, etas should be a tuple with the size of 2, but got {}.".format(len(etas)))
|
||||
|
||||
if not isinstance(step_sizes, tuple):
|
||||
raise TypeError("For Rprop, step_sizes should be a tuple, but got {}.".format(type(etas)))
|
||||
if len(step_sizes) != 2:
|
||||
raise ValueError("For Rprop, step_sizes should be a tuple with the size of 2, "
|
||||
"but got {}.".format(len(step_sizes)))
|
||||
|
||||
if step_sizes[0] > step_sizes[1]:
|
||||
raise ValueError("For Rprop, maximal step size should not be less than minimal step size, "
|
||||
"but got {} > {}.".format(step_sizes[0], step_sizes[1]))
|
||||
|
||||
validator.check_float_range(etas[0], 0.0, 1.0, Rel.INC_NEITHER, "etaminus", self.cls_name)
|
||||
validator.check_value_type("etaplus", etas[1], [float], self.cls_name)
|
||||
if etas[1] <= 1.0:
|
||||
raise ValueError("For Rprop, etaplus should be greater than 1.0, but got etaplus {}.".format(etas[1]))
|
||||
|
||||
validator.check_value_type("min_step_sizes", step_sizes[0], [float], self.cls_name)
|
||||
validator.check_value_type("max_step_sizes", step_sizes[1], [float], self.cls_name)
|
||||
|
||||
self.etaminus, self.etaplus = etas
|
||||
self.step_size_min, self.step_size_max = step_sizes
|
||||
self.prev = self.parameters.clone(prefix="prev", init='zeros')
|
||||
self.step_size = self.parameters.clone(prefix="step_size", init='zeros')
|
||||
self.step = Parameter(Tensor(0., dtype=mstype.float32), name='step')
|
||||
|
||||
self.fill = P.Fill()
|
||||
self.sign = P.Sign()
|
||||
self.assign = P.Assign()
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.cast = P.Cast()
|
||||
self.select = P.Select()
|
||||
self.ones_like = P.OnesLike()
|
||||
|
||||
def construct(self, gradients):
|
||||
gradients = self.decay_weight(gradients)
|
||||
gradients = self.gradients_centralization(gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lrs = self.get_lr()
|
||||
success = True
|
||||
|
||||
for index, (grad, param, prev, step_size) in enumerate(zip(gradients, self.parameters,
|
||||
self.prev, self.step_size)):
|
||||
lr = lrs[index] if self.is_group_lr else lrs
|
||||
|
||||
if self.step == 0.:
|
||||
step_size_fp32 = self.ones_like(step_size) * lr
|
||||
else:
|
||||
step_size_fp32 = self.cast(step_size, mstype.float32)
|
||||
|
||||
gradient_fp32 = self.cast(grad, mstype.float32)
|
||||
param_fp32 = self.cast(param, mstype.float32)
|
||||
|
||||
sign = self.sign(gradient_fp32 * prev)
|
||||
sign = self.select(sign > 0, self.fill(mstype.float32, sign.shape, self.etaplus), sign)
|
||||
sign = self.select(sign < 0, self.fill(mstype.float32, sign.shape, self.etaminus), sign)
|
||||
sign = self.select(sign == 0, self.fill(mstype.float32, sign.shape, 1.), sign)
|
||||
|
||||
step_size_fp32 = ops.clip_by_value(step_size_fp32 * sign, self.step_size_min, self.step_size_max)
|
||||
|
||||
gradient_update = self.select(sign == self.etaminus, self.fill(mstype.float32, sign.shape, 0.),
|
||||
gradient_fp32)
|
||||
next_param = param_fp32 - self.sign(gradient_update) * step_size_fp32
|
||||
|
||||
self.assign(param, self.cast(next_param, param.dtype))
|
||||
self.assign(prev, self.cast(gradient_update, prev.dtype))
|
||||
self.assign(step_size, self.cast(step_size_fp32, step_size.dtype))
|
||||
|
||||
success = F.depend(success, self.assignadd(self.step, 1.))
|
||||
|
||||
return success
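The same sign-based rule can be sketched outside the graph as well; a minimal NumPy version of one Rprop step under the same defaults (the helper is only illustrative, not part of the MindSpore API):

import numpy as np

def rprop_step(w, grad, prev_grad, step_size, step, lr=0.1,
               etaminus=0.5, etaplus=1.2, step_size_min=1e-6, step_size_max=50.):
    """One Rprop update mirroring construct() above (plain NumPy sketch)."""
    if step == 0:
        step_size = np.ones_like(w) * lr            # first step: seed step sizes from the learning rate
    sign = np.sign(grad * prev_grad)
    factor = np.where(sign > 0, etaplus, np.where(sign < 0, etaminus, 1.))
    step_size = np.clip(step_size * factor, step_size_min, step_size_max)
    # On a sign change the gradient is zeroed so the step is not immediately reversed again.
    grad_update = np.where(factor == etaminus, 0., grad)
    w = w - np.sign(grad_update) * step_size
    return w, grad_update, step_size

# Tiny usage example on a simple quadratic objective.
w, prev, sizes = np.array([1.0, -2.0]), np.zeros(2), np.zeros(2)
for t in range(3):
    w, prev, sizes = rprop_step(w, grad=2. * w, prev_grad=prev, step_size=sizes, step=t)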
@@ -0,0 +1,208 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import mindspore
|
||||
from mindspore import nn, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn.optim import ASGD
|
||||
from mindspore.nn.optim import Rprop
|
||||
np.random.seed(1024)
|
||||
|
||||
fc1_weight = np.array([[0.72346634, 0.95608497, 0.4084163, 0.18627149,
|
||||
0.6942514, 0.39767185, 0.24918061, 0.4548748],
|
||||
[0.7203382, 0.19086994, 0.76286614, 0.87920564,
|
||||
0.3169892, 0.9462494, 0.62827677, 0.27504718],
|
||||
[0.3544535, 0.2524781, 0.5370583, 0.8313121,
|
||||
0.6670143, 0.0488653, 0.62225235, 0.7546456],
|
||||
[0.17985944, 0.05106374, 0.31064633, 0.4863033,
|
||||
0.848814, 0.5523157, 0.20295663, 0.7213356]]).astype("float32")
|
||||
|
||||
fc1_bias = np.array([0.79708564, 0.13728078, 0.66322654, 0.88128525]).astype("float32")
|
||||
|
||||
fc2_weight = np.array([[0.8473515, 0.50923985, 0.42287776, 0.29769543]]).astype("float32")
|
||||
|
||||
fc2_bias = np.array([0.09996348]).astype("float32")
|
||||
|
||||
|
||||
def make_fake_data():
|
||||
"""
|
||||
make fake data
|
||||
"""
|
||||
data, label = [], []
|
||||
for i in range(20):
|
||||
data.append(mindspore.Tensor(np.array(np.ones((2, 8)) * i, dtype=np.float32)))
|
||||
label.append(mindspore.Tensor(np.array(np.ones((2, 1)) * (i + 1), dtype=np.float32)))
|
||||
return data, label
|
||||
|
||||
|
||||
class NetWithLoss(nn.Cell):
|
||||
"""
|
||||
build net with loss
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(NetWithLoss, self).__init__()
|
||||
self.network = network
|
||||
self.loss = nn.MSELoss(reduction='sum')
|
||||
|
||||
def construct(self, x, label):
|
||||
out = self.network(x)
|
||||
loss = self.loss(out, label)
|
||||
return loss
|
||||
|
||||
|
||||
class FakeNet(nn.Cell):
|
||||
"""
|
||||
build fake net
|
||||
"""
|
||||
def __init__(self):
|
||||
super(FakeNet, self).__init__()
|
||||
self.fc1 = nn.Dense(in_channels=8, out_channels=4, weight_init=Tensor(fc1_weight), bias_init=Tensor(fc1_bias))
|
||||
self.fc2 = nn.Dense(in_channels=4, out_channels=1, weight_init=Tensor(fc2_weight), bias_init=Tensor(fc2_bias))
|
||||
self.relu = nn.ReLU()
|
||||
self.reducemean = P.ReduceMean()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.relu(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
def _initialize_weights(self):
|
||||
"""
|
||||
parameter initialization
|
||||
"""
|
||||
self.init_parameters_data()
|
||||
for name, m in self.cells_and_names():
|
||||
if name == 'fc1':
|
||||
m.weight.set_data(Tensor(fc1_weight))
|
||||
m.bias.set_data(Tensor(fc1_bias))
|
||||
elif name == 'fc2':
|
||||
m.weight.set_data(Tensor(fc2_weight))
|
||||
m.bias.set_data(Tensor(fc2_bias))
|
||||
|
||||
|
||||
def build_network(opt_config, is_group=False):
|
||||
"""
|
||||
Construct training
|
||||
"""
|
||||
losses = []
|
||||
net = FakeNet()
|
||||
|
||||
networkwithloss = NetWithLoss(net)
|
||||
networkwithloss.set_train()
|
||||
|
||||
if is_group:
|
||||
fc1_params = list(filter(lambda x: 'fc1' in x.name, networkwithloss.trainable_params()))
|
||||
fc2_params = list(filter(lambda x: 'fc1' not in x.name, networkwithloss.trainable_params()))
|
||||
if opt_config['name'] == 'ASGD':
|
||||
params = [{'params': fc1_params, 'weight_decay': 0.01, 'lr': 0.001}, {'params': fc2_params, 'lr': 0.1}]
|
||||
else:
|
||||
params = [{'params': fc1_params, 'lr': 0.001}, {'params': fc2_params, 'lr': 0.1}]
|
||||
else:
|
||||
params = networkwithloss.trainable_params()
|
||||
|
||||
if opt_config['name'] == 'ASGD':
|
||||
net_opt = ASGD(params, learning_rate=opt_config['lr'], lambd=opt_config['lambd'], alpha=opt_config['alpha'],
|
||||
t0=opt_config['t0'], weight_decay=opt_config['weight_decay'])
|
||||
|
||||
elif opt_config['name'] == 'Rprop':
|
||||
net_opt = Rprop(params, learning_rate=opt_config['lr'], etas=opt_config['etas'],
|
||||
step_sizes=opt_config['step_sizes'], weight_decay=0.0)
|
||||
|
||||
trainonestepcell = mindspore.nn.TrainOneStepCell(networkwithloss, net_opt)
|
||||
data, label = make_fake_data()
|
||||
for i in range(20):
|
||||
loss = trainonestepcell(data[i], label[i])
|
||||
losses.append(loss.asnumpy())
|
||||
if opt_config['name'] == 'ASGD':
|
||||
return np.array(losses), net_opt
|
||||
return np.array(losses)
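The test entry points in the files below drive this harness with a config dict; a typical invocation, using the same keys as those tests:

asgd_config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0}
asgd_losses, asgd_opt = build_network(asgd_config)          # ASGD also returns the optimizer cell

rprop_config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0}
rprop_losses = build_network(rprop_config, is_group=True)   # Rprop returns only the loss array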
|
||||
|
||||
|
||||
loss_default_asgd = np.array([3.01246792e-01, 1.20041794e+02, 1.38681079e+03, 2.01250820e+01,
|
||||
3.27283554e+01, 4.76963005e+01, 6.47094269e+01, 8.34786530e+01,
|
||||
1.03742706e+02, 1.25265739e+02, 1.47835190e+02, 1.71259613e+02,
|
||||
1.95367035e+02, 2.20003204e+02, 2.45029831e+02, 2.70323456e+02,
|
||||
2.95774048e+02, 3.21283752e+02, 3.46765594e+02, 3.72143097e+02], dtype=np.float32)
|
||||
|
||||
loss_not_default_asgd = np.array([3.01246792e-01, 1.26019104e+02, 1.90600449e+02, 9.70605755e+00,
|
||||
2.98419113e+01, 3.68430023e+02, 1.06318066e+04, 1.35017746e+02,
|
||||
1.68673813e+02, 2.05914215e+02, 2.46694992e+02, 2.90972443e+02,
|
||||
3.38703430e+02, 3.89845123e+02, 4.44355103e+02, 5.02191406e+02,
|
||||
5.63312500e+02, 6.27676941e+02, 6.95244202e+02, 7.65973816e+02], dtype=np.float32)
|
||||
|
||||
loss_group_asgd = np.array([3.01246792e-01, 7.26708527e+01, 2.84905312e+05, 4.17499258e+04,
|
||||
1.46797949e+04, 5.07966602e+03, 1.70935132e+03, 5.47094910e+02,
|
||||
1.59216995e+02, 3.78818207e+01, 5.18196869e+00, 2.62275129e-03,
|
||||
2.09768796e+00, 5.23108435e+00, 7.78943682e+00, 9.57108879e+00,
|
||||
1.07310610e+01, 1.14618425e+01, 1.19147835e+01, 1.21936722e+01], dtype=np.float32)
|
||||
|
||||
|
||||
loss_default_rprop = np.array([3.01246792e-01, 1.19871742e+02, 4.13467163e+02, 8.09146179e+02,
|
||||
1.22364807e+03, 1.56787573e+03, 1.75733594e+03, 1.72866272e+03,
|
||||
1.46183936e+03, 1.00406335e+03, 4.84076874e+02, 9.49734650e+01,
|
||||
2.00592804e+01, 1.87920704e+01, 1.53733969e+01, 1.85836582e+01,
|
||||
5.21527790e-02, 2.01522671e-02, 7.19913816e+00, 8.52459526e+00], dtype=np.float32)
|
||||
|
||||
loss_not_default_rprop = np.array([3.0124679e-01, 1.2600269e+02, 4.7351608e+02, 1.0220379e+03,
|
||||
1.7181555e+03, 2.4367019e+03, 2.9170872e+03, 2.7243464e+03,
|
||||
1.4999669e+03, 7.5820435e+01, 1.0590715e+03, 5.4336096e+02,
|
||||
7.0162407e+01, 8.2754419e+02, 9.6329260e+02, 3.4475109e+01,
|
||||
5.3843134e+02, 6.0064526e+02, 1.1046149e+02, 3.5530117e+03], dtype=np.float32)
|
||||
|
||||
loss_group_rprop = np.array([3.0124679e-01, 7.1360558e+01, 4.8910957e+01, 2.1730331e+02,
|
||||
3.0747052e+02, 5.2734237e+00, 5.6865869e+00, 1.7116127e+02,
|
||||
2.0539343e+02, 2.2993685e+01, 2.6194101e+02, 2.8772815e+02,
|
||||
2.4236647e+01, 3.9299741e+02, 3.5600668e+02, 1.4759110e+01,
|
||||
7.2244568e+02, 8.1952783e+02, 9.8913864e+01, 1.1141744e+03], dtype=np.float32)
|
||||
|
||||
|
||||
default_fc1_weight_asgd = np.array([[-0.9451941, -0.71258026, -1.2602371, -1.4823773,
|
||||
-0.974408, -1.2709816, -1.4194703, -1.2137808],
|
||||
[-1.5341775, -2.0636342, -1.4916497, -1.3753126,
|
||||
-1.9375193, -1.308271, -1.6262367, -1.9794592],
|
||||
[-1.9886293, -2.0906024, -1.8060291, -1.5117803,
|
||||
-1.6760755, -2.2942104, -1.7208353, -1.5884445],
|
||||
[-2.071215, -2.2000103, -1.9404325, -1.7647781,
|
||||
-1.4022746, -1.6987679, -2.0481179, -1.5297506]], dtype=np.float32)
|
||||
default_fc1_bias_asgd = np.array([-0.17978168, -1.0764512, -0.578816, -0.2928958], dtype=np.float32)
|
||||
default_fc2_weight_asgd = np.array([[4.097412, 6.2694297, 5.9203916, 5.3845487]], dtype=np.float32)
|
||||
default_fc2_bias_asgd = np.array([6.904814], dtype=np.float32)
|
||||
|
||||
|
||||
no_default_fc1_weight_asgd = np.array([[-1.3406217, -1.1080127, -1.655658, -1.8777936,
|
||||
-1.3698348, -1.6664025, -1.8148884, -1.6092018],
|
||||
[-1.1475986, -1.6770473, -1.1050745, -0.98873824,
|
||||
-1.5509329, -0.9216978, -1.2396574, -1.5928726],
|
||||
[-1.2329121, -1.334883, -1.050313, -0.756071,
|
||||
-0.92036265, -1.5384867, -0.96512324, -0.8327349],
|
||||
[-1.0685704, -1.1973612, -0.9377885, -0.7621386,
|
||||
-0.39964262, -0.69612867, -1.0454736, -0.52711576]], dtype=np.float32)
|
||||
no_default_fc1_bias_asgd = np.array([0.41264832, -0.19961096, 0.37743938, 0.65807366], dtype=np.float32)
|
||||
no_default_fc2_weight_asgd = np.array([[-5.660916, -5.9415145, -5.1402636, -4.199707]], dtype=np.float32)
|
||||
no_default_fc2_bias_asgd = np.array([0.5082278], dtype=np.float32)
|
||||
|
||||
|
||||
no_default_group_fc1_weight_asgd = np.array([[-32.526627, -32.29401, -32.8416, -33.06367, -32.55584,
|
||||
-32.852345, -33.000767, -32.795143],
|
||||
[-33.164936, -33.69432, -33.12241, -33.006073, -33.568207,
|
||||
-32.9391, -33.256996, -33.61015],
|
||||
[-33.118973, -33.220943, -32.936436, -32.642193, -32.806488,
|
||||
-33.424484, -32.85125, -32.718857],
|
||||
[-30.155754, -30.284513, -30.025005, -29.849358, -29.486917,
|
||||
-29.783375, -30.132658, -29.614393]], dtype=np.float32)
|
||||
no_default_group_fc1_bias_asgd = np.array([-15.838092, -16.811989, -16.078112, -14.289094], dtype=np.float32)
|
||||
no_default_group_fc2_weight_asgd = np.array([[1288.7146, 1399.3041, 1292.8445, 1121.4629]], dtype=np.float32)
|
||||
no_default_group_fc2_bias_asgd = np.array([18.513494], dtype=np.float32)
@@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
from .optimizer_utils import build_network, loss_not_default_asgd, loss_default_asgd, loss_group_asgd
|
||||
|
||||
|
||||
def test_default_asgd_pynative():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Pynative mode with default parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import default_fc1_weight_asgd, \
|
||||
default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
|
||||
config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0}
|
||||
loss, cells = build_network(config)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_default_asgd, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_default_asgd_graph():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Graph mode with default parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import default_fc1_weight_asgd, \
|
||||
default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
|
||||
config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0}
|
||||
loss, cells = build_network(config)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_default_asgd, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_no_default_asgd_pynative():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Pynative mode with another set of parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import no_default_fc1_weight_asgd, \
|
||||
no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
|
||||
config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
|
||||
loss, cells = build_network(config)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3)
|
||||
|
||||
|
||||
def test_no_default_asgd_graph():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Graph mode with another set of parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import no_default_fc1_weight_asgd, \
|
||||
no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
|
||||
config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
|
||||
loss, cells = build_network(config)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3)
|
||||
|
||||
|
||||
def test_default_asgd_group_pynative():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Pynative mode with parameter grouping
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \
|
||||
no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
|
||||
config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
|
||||
loss, cells = build_network(config, is_group=True)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3)
|
||||
|
||||
|
||||
def test_default_asgd_group_graph():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Graph mode with parameter grouping
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \
|
||||
no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
|
||||
config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
|
||||
loss, cells = build_network(config, is_group=True)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3)
@@ -0,0 +1,72 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
from .optimizer_utils import build_network, loss_not_default_asgd, loss_default_asgd, loss_group_asgd
|
||||
|
||||
|
||||
def test_default_asgd_graph():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Graph mode with default parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import default_fc1_weight_asgd, \
|
||||
default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0}
|
||||
loss, cells = build_network(config)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_default_asgd, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_no_default_asgd_graph():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Graph mode with another set of parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import no_default_fc1_weight_asgd, \
|
||||
no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
|
||||
loss, cells = build_network(config)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3)
|
||||
|
||||
|
||||
|
||||
def test_default_asgd_group_graph():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Graph mode with parameter grouping
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \
|
||||
no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
|
||||
loss, cells = build_network(config, is_group=True)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3)
@@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
from .optimizer_utils import build_network, loss_not_default_asgd, loss_default_asgd, loss_group_asgd
|
||||
|
||||
|
||||
def test_default_asgd_pynative():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Pynative mode with default parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import default_fc1_weight_asgd, \
|
||||
default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0}
|
||||
loss, cells = build_network(config)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_default_asgd, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_default_asgd_graph():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Graph mode with default parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import default_fc1_weight_asgd, \
|
||||
default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0}
|
||||
loss, cells = build_network(config)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_default_asgd, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_no_default_asgd_pynative():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Pynative mode with another set of parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import no_default_fc1_weight_asgd, \
|
||||
no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
|
||||
loss, cells = build_network(config)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3)
|
||||
|
||||
|
||||
def test_no_default_asgd_graph():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Graph mode with another set of parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import no_default_fc1_weight_asgd, \
|
||||
no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
|
||||
loss, cells = build_network(config)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3)
|
||||
|
||||
|
||||
def test_default_asgd_group_pynative():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Pynative mode with parameter grouping
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \
|
||||
no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
|
||||
loss, cells = build_network(config, is_group=True)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3)
|
||||
|
||||
|
||||
def test_default_asgd_group_graph():
|
||||
"""
|
||||
Feature: Test ASGD optimizer
|
||||
Description: Test ASGD in Graph mode with parameter grouping
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \
|
||||
no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
|
||||
loss, cells = build_network(config, is_group=True)
|
||||
assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5)
|
||||
assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5)
|
||||
assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3)
@@ -0,0 +1,90 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
from .optimizer_utils import build_network, loss_default_rprop, loss_group_rprop, loss_not_default_rprop
|
||||
|
||||
|
||||
def test_default_rprop_pynative():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Pynative mode with default parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
|
||||
config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0}
|
||||
loss = build_network(config)
|
||||
assert np.allclose(loss_default_rprop, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_default_rprop_graph():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Graph mode with default parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
|
||||
config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0}
|
||||
loss = build_network(config)
|
||||
assert np.allclose(loss_default_rprop, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_no_default_rprop_pynative():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Pynative mode with another set of parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
|
||||
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0}
|
||||
loss = build_network(config)
|
||||
assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_no_default_rprop_graph():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Graph mode with another set of parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
|
||||
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0}
|
||||
loss = build_network(config)
|
||||
assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_default_rprop_group_pynative():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Pynative mode with parameter grouping
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
|
||||
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0}
|
||||
loss = build_network(config, is_group=True)
|
||||
assert np.allclose(loss_group_rprop, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_default_rprop_group_graph():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Graph mode with parameter grouping
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
|
||||
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0}
|
||||
loss = build_network(config, is_group=True)
|
||||
assert np.allclose(loss_group_rprop, loss, atol=1.e-5)
@@ -0,0 +1,54 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
from .optimizer_utils import build_network, loss_default_rprop, loss_group_rprop, loss_not_default_rprop
|
||||
|
||||
|
||||
def test_default_rprop_graph():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Graph mode with default parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0}
|
||||
loss = build_network(config)
|
||||
assert np.allclose(loss_default_rprop, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_no_default_rprop_graph():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Graph mode with another set of parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0}
|
||||
loss = build_network(config)
|
||||
assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_default_rprop_group_graph():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Graph mode with parameter grouping
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0}
|
||||
loss = build_network(config, is_group=True)
|
||||
assert np.allclose(loss_group_rprop, loss, atol=1.e-5)
@@ -0,0 +1,90 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
from .optimizer_utils import build_network, loss_default_rprop, loss_group_rprop, loss_not_default_rprop
|
||||
|
||||
|
||||
def test_default_rprop_pynative():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Pynative mode with default parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0}
|
||||
loss = build_network(config)
|
||||
assert np.allclose(loss_default_rprop, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_default_rprop_graph():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Graph mode with default parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0}
|
||||
loss = build_network(config)
|
||||
assert np.allclose(loss_default_rprop, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_no_default_rprop_pynative():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Pynative mode with another set of parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0}
|
||||
loss = build_network(config)
|
||||
assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_no_default_rprop_graph():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Graph mode with another set of parameter
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0}
|
||||
loss = build_network(config)
|
||||
assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_default_rprop_group_pynative():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Pynative mode with parameter grouping
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0}
|
||||
loss = build_network(config, is_group=True)
|
||||
assert np.allclose(loss_group_rprop, loss, atol=1.e-5)
|
||||
|
||||
|
||||
def test_default_rprop_group_graph():
|
||||
"""
|
||||
Feature: Test Rprop optimizer
|
||||
Description: Test Rprop in Graph mode with parameter grouping
|
||||
Expectation: Loss values and parameters conform to preset values.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0}
|
||||
loss = build_network(config, is_group=True)
|
||||
assert np.allclose(loss_group_rprop, loss, atol=1.e-5)