!25557 Add ASGD and Rprop optimizer to Mindspore

Merge pull request !25557 from wanyiming/optimizer
This commit is contained in:
i-robot 2021-11-23 07:17:11 +00:00 committed by Gitee
commit 15cc4673a0
11 changed files with 1172 additions and 1 deletions

View File

@ -23,6 +23,8 @@ from .momentum import Momentum
from .adam import Adam, AdamWeightDecay, AdamOffload
from .lamb import Lamb
from .sgd import SGD
from .asgd import ASGD
from .rprop import Rprop
from .lars import LARS
from .ftrl import FTRL
from .rmsprop import RMSProp
@ -33,4 +35,4 @@ from .thor import thor
from .adafactor import AdaFactor
__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam', 'AdamOffload',
'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad', 'Adagrad', 'thor', 'AdaFactor']
'Lamb', 'SGD', 'ASGD', 'Rprop', 'FTRL', 'RMSProp', 'ProximalAdagrad', 'Adagrad', 'thor', 'AdaFactor']

190
mindspore/nn/optim/asgd.py Executable file
View File

@ -0,0 +1,190 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""asgd"""
from mindspore.ops import functional as F, operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore
from mindspore._checkparam import Validator as validator
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
class ASGD(Optimizer):
r"""
Implements Average Stochastic Gradient Descent.
Introduction to ASGD can be found at `Acceleration of stochastic approximation by average
<http://dl.acm.org/citation.cfm?id=131098>`_.
The updating formulas are as follows:
\begin{gather*}
w_{t} = w_{t-1} * (1 - \lambda * \eta_{t-1}) - \eta_{t-1} * g_{t} \\
ax_{t} = (w_t - ax_{t-1}) * \mu_{t-1} \\
\eta_{t} = \frac{1.}{(1 + \lambda * lr * t)^\alpha} \\
\mu_{t} = \frac{1}{\max(1, t - t0)}
\end{gather*}
:math:`\lambda` represents the decay term, :math:`\mu` and :math:`\eta` are tracked to
update :math:`ax` and :math:`w`, :math:`t0` represents the point of starting averaging,
:math:`\alpha` represents the power for eta update, :math:`ax` represents the averaged
parameter value, :math:`t` represents the current step, :math:`g` represents `gradients`,
:math:`w` represents `params`.
Note:
If parameters are not grouped, the `weight_decay` in optimizer will be applied on the parameters without 'beta'
or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight. When parameters
are grouped, each group can set `weight_decay`, if not, the `weight_decay` in optimizer will be applied.
Args:
params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
`parameters` is a list of `dict`, the "params", "lr", "weight_decay", "grad_centralization" and
"order_params" are the keys can be parsed.
- params: Required. Parameters in current group. The value must be a list of `Parameter`.
- lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
will be used. If not, the `weight_decay` in the optimizer will be used.
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
convolution layer.
- order_params: Optional. When parameters is grouped, this usually is used to maintain the order of
parameters that appeared in the network to improve performance. The value should be parameters whose
order will be followed in optimizer.
If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
one group of `params`.
learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]):
- float: The fixed learning rate value. Must be equal to or greater than 0.
- int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float.
- Tensor: Its value should be a scalar or a 1-D vector. For scalar, fixed learning rate will be applied.
For vector, learning rate is dynamic, then the i-th step will take the i-th value as the learning rate.
- Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate.
- LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of
LearningRateSchedule with step as the input to get the learning rate of current step.
lambd (float): The decay term. Default: 1e-4.
alpha (float): The power for eta update. Default: 0.75.
t0 (float): The point of starting averaging. Default: 1e6.
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
Tensor[bool], the value is True.
Raises:
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
TypeError: If element of `parameters` is neither Parameter nor dict.
TypeError: If `lambd`, `alpha` or `t0` is not a float.
TypeError: If `weight_decay` is neither float nor int.
ValueError: If `weight_decay` is less than 0.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.ASGD(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params,'grad_centralization':True},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = nn.ASGD(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 default weight decay of 0.0 and grad
>>> # centralization of True.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad
>>> # centralization of False.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
@opt_init_args_register
def __init__(self, params, learning_rate=0.1, lambd=1e-4, alpha=0.75, t0=1e6, weight_decay=0.):
super(ASGD, self).__init__(learning_rate, params, weight_decay)
validator.check_value_type("lambd", lambd, [float], self.cls_name)
validator.check_value_type("alpha", alpha, [float], self.cls_name)
validator.check_value_type("t0", t0, [float], self.cls_name)
self.lambd = lambd
self.alpha = alpha
self.t0 = Tensor([t0], dtype=mstype.float32)
mu, eta = [], []
for param in self.parameters:
mu.append(Parameter(Tensor(1., dtype=mstype.float32), name='mu_'+param.name))
eta.append(Parameter(Tensor(0., dtype=mstype.float32), name='eta_'+param.name))
self.lens = len(self.parameters)
self.mu = mindspore.ParameterTuple(mu)
self.eta = mindspore.ParameterTuple(eta)
self.step = Parameter(Tensor(1., dtype=mstype.float32), name='step')
self.ax = self.parameters.clone(prefix="ax_", init='zeros')
self.pow = P.Pow()
self.maximum = P.Maximum()
self.assign = P.Assign()
self.assignadd = P.AssignAdd()
self.assignsub = P.AssignSub()
self.cast = P.Cast()
self.squeeze = P.Squeeze()
def construct(self, gradients):
gradients = self.decay_weight(gradients)
gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients)
lrs = self.get_lr()
success = True
for index, (grad, param, mu, eta, ax) in enumerate(zip(gradients, self.parameters, self.mu, self.eta, self.ax)):
lr = lrs[index] if self.is_group_lr else lrs
if self.step == 1.:
self.assign(eta, lr)
param_fp32 = self.cast(param, mstype.float32)
gradient_fp32 = self.cast(grad, mstype.float32)
ax_fp32 = self.cast(ax, mstype.float32)
param_fp32 = param_fp32 * (1. - self.lambd * eta) - eta * gradient_fp32
self.assign(param, self.cast(param_fp32, param.dtype))
if mu != 1:
self.assignadd(ax, self.cast((param_fp32 - ax_fp32) * mu, ax.dtype))
else:
self.assign(ax, param)
self.assign(eta, lr / (self.pow((1. + (self.lambd * lr * self.step)), self.alpha)))
self.assign(mu, 1. / self.squeeze(self.maximum(1., self.step - self.t0)))
success = F.depend(success, self.assignadd(self.step, 1.))
return success

215
mindspore/nn/optim/rprop.py Executable file
View File

@ -0,0 +1,215 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""rprop"""
from mindspore import ops
from mindspore.ops import functional as F, operations as P
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
class Rprop(Optimizer):
r"""
Implements Resilient backpropagation.
Further information about this implementation can be found at `A Direct Adaptive Method for Faster Backpropagation
Learning: The RPROP Algorithm <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417>`_.
The updating formulas are as follows:
.. math::
\begin{gather*}
&\hspace{0mm} \textbf{if} \: g_{t-1} g_t > 0 \\
&\hspace{5mm} \Delta_t \leftarrow \mathrm{min}(\Delta_{t-1} \eta_{+}, \Delta_{max}) \\
&\hspace{0mm} \textbf{else if} \: g_{t-1} g_t < 0 \\
&\hspace{5mm} \Delta_t \leftarrow \mathrm{max}(\Delta_{t-1} \eta_{-}, \Delta_{min}) \\
&\hspace{mm} \textbf{else} \: \\
&\hspace{5mm} \Delta_t \leftarrow \Delta_{t-1} \\
&\hspace{0mm} w_{t} \leftarrow w_{t-1}- \Delta_{t} \mathrm{sign}(g_t) \\
\end{gather*}
:math:`\Delta_{min/max}` represents the min/max step size, :math:`\eta_{+/-}` represents the factors of
etaminus and etaplus, :math:`g` represents `gradients`, :math:`w` represents `parameters`.
Note:
If parameters are not grouped, the `weight_decay` in optimizer will be applied on the parameters without 'beta'
or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight. When parameters
are grouped, each group can set `weight_decay`, if not, the `weight_decay` in optimizer will be applied.
Args:
params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
`parameters` is a list of `dict`, the "params", "lr", "weight_decay", "grad_centralization" and
"order_params" are the keys can be parsed.
- params: Required. Parameters in current group. The value must be a list of `Parameter`.
- lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
will be used. If not, the `weight_decay` in the optimizer will be used.
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
convolution layer.
- order_params: Optional. When parameters is grouped, this usually is used to maintain the order of
parameters that appeared in the network to improve performance. The value should be parameters whose
order will be followed in optimizer.
If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
one group of `params`.
learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]):
- float: The fixed learning rate value. Must be equal to or greater than 0.
- int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float.
- Tensor: Its value should be a scalar or a 1-D vector. For scalar, fixed learning rate will be applied.
For vector, learning rate is dynamic, then the i-th step will take the i-th value as the learning rate.
- Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate.
- LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of
LearningRateSchedule with step as the input to get the learning rate of current step.
etas (tuple[float, float]): The factor of multiplicative increasing or
descreasing(etaminus, etaplus).
step_sizes(tuple[float, float]): The allowed minimal and maximal step size(min_step_sizes, max_step_size).
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
Tensor[bool], the value is True.
Raises:
TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
TypeError: If element of `parameters` is neither Parameter nor dict.
TypeError: If `step_sizes` or `etas` is not a tuple.
ValueError: If maximal step size is less than minimal step size.
ValueError: If the length of `step_sizes` or `ets` is not equal to 2.
TypeError: If the element in `etas` or `step_sizes` is not a float.
ValueError: If `etaminus` is not in the range of (0, 1) or `etaplus` is not greater than 1.
TypeError: If `weight_decay` is neither float nor int.
ValueError: If `weight_decay` is less than 0.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.Rprop(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params,'grad_centralization':True},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = nn.Rprop(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 default weight decay of 0.0 and grad
>>> # centralization of True.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad
>>> # centralization of False.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
@opt_init_args_register
def __init__(self, params, learning_rate=0.1, etas=(0.5, 1.2), step_sizes=(1e-6, 50.), weight_decay=0.1):
super(Rprop, self).__init__(learning_rate, params, weight_decay)
if not isinstance(etas, tuple):
raise TypeError("For Rprop, etas should be a tuple, but got {}.".format(type(etas)))
if len(etas) != 2:
raise ValueError("For Rprop, etas should be a tuple with the size of 2, but got {}.".format(len(etas)))
if not isinstance(step_sizes, tuple):
raise TypeError("For Rprop, step_sizes should be a tuple, but got {}.".format(type(etas)))
if len(step_sizes) != 2:
raise ValueError("For Rprop, step_sizes should be a tuple with the size of 2, "
"but got {}.".format(len(step_sizes)))
if step_sizes[0] > step_sizes[1]:
raise ValueError("For Rprop, maximal step size should not be less than minimal step size, "
"but got {} > {}.".format(step_sizes[0], step_sizes[1]))
validator.check_float_range(etas[0], 0.0, 1.0, Rel.INC_NEITHER, "etaminus", self.cls_name)
validator.check_value_type("etaplus", etas[1], [float], self.cls_name)
if etas[1] <= 1.0:
raise ValueError("For Rprop, etaplus should be greater than 1.0, but got etaplus {}.".format(etas[1]))
validator.check_value_type("min_step_sizes", step_sizes[0], [float], self.cls_name)
validator.check_value_type("max_step_sizes", step_sizes[1], [float], self.cls_name)
self.etaminus, self.etaplus = etas
self.step_size_min, self.step_size_max = step_sizes
self.prev = self.parameters.clone(prefix="prev", init='zeros')
self.step_size = self.parameters.clone(prefix="step_size", init='zeros')
self.step = Parameter(Tensor(0., dtype=mstype.float32), name='step')
self.fill = P.Fill()
self.sign = P.Sign()
self.assign = P.Assign()
self.assignadd = P.AssignAdd()
self.cast = P.Cast()
self.select = P.Select()
self.ones_like = P.OnesLike()
def construct(self, gradients):
gradients = self.decay_weight(gradients)
gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients)
lrs = self.get_lr()
success = True
for index, (grad, param, prev, step_size) in enumerate(zip(gradients, self.parameters,
self.prev, self.step_size)):
lr = lrs[index] if self.is_group_lr else lrs
if self.step == 0.:
step_size_fp32 = self.ones_like(step_size) * lr
else:
step_size_fp32 = self.cast(step_size, mstype.float32)
gradient_fp32 = self.cast(grad, mstype.float32)
param_fp32 = self.cast(param, mstype.float32)
sign = self.sign(gradient_fp32 * prev)
sign = self.select(sign > 0, self.fill(mstype.float32, sign.shape, self.etaplus), sign)
sign = self.select(sign < 0, self.fill(mstype.float32, sign.shape, self.etaminus), sign)
sign = self.select(sign == 0, self.fill(mstype.float32, sign.shape, 1.), sign)
step_size_fp32 = ops.clip_by_value(step_size_fp32 * sign, self.step_size_min, self.step_size_max)
gradient_update = self.select(sign == self.etaminus, self.fill(mstype.float32, sign.shape, 0.),
gradient_fp32)
next_param = param_fp32 - self.sign(gradient_update) * step_size_fp32
self.assign(param, self.cast(next_param, param.dtype))
self.assign(prev, self.cast(gradient_update, prev.dtype))
self.assign(step_size, self.cast(step_size_fp32, step_size.dtype))
success = F.depend(success, self.assignadd(self.step, 1.))
return success

View File

View File

@ -0,0 +1,208 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore
from mindspore import nn, Tensor
from mindspore.ops import operations as P
from mindspore.nn.optim import ASGD
from mindspore.nn.optim import Rprop
np.random.seed(1024)
fc1_weight = np.array([[0.72346634, 0.95608497, 0.4084163, 0.18627149,
0.6942514, 0.39767185, 0.24918061, 0.4548748],
[0.7203382, 0.19086994, 0.76286614, 0.87920564,
0.3169892, 0.9462494, 0.62827677, 0.27504718],
[0.3544535, 0.2524781, 0.5370583, 0.8313121,
0.6670143, 0.0488653, 0.62225235, 0.7546456],
[0.17985944, 0.05106374, 0.31064633, 0.4863033,
0.848814, 0.5523157, 0.20295663, 0.7213356]]).astype("float32")
fc1_bias = np.array([0.79708564, 0.13728078, 0.66322654, 0.88128525]).astype("float32")
fc2_weight = np.array([[0.8473515, 0.50923985, 0.42287776, 0.29769543]]).astype("float32")
fc2_bias = np.array([0.09996348]).astype("float32")
def make_fake_data():
"""
make fake data
"""
data, label = [], []
for i in range(20):
data.append(mindspore.Tensor(np.array(np.ones((2, 8)) * i, dtype=np.float32)))
label.append(mindspore.Tensor(np.array(np.ones((2, 1)) * (i + 1), dtype=np.float32)))
return data, label
class NetWithLoss(nn.Cell):
"""
build net with loss
"""
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.network = network
self.loss = nn.MSELoss(reduction='sum')
def construct(self, x, label):
out = self.network(x)
loss = self.loss(out, label)
return loss
class FakeNet(nn.Cell):
"""
build fake net
"""
def __init__(self):
super(FakeNet, self).__init__()
self.fc1 = nn.Dense(in_channels=8, out_channels=4, weight_init=Tensor(fc1_weight), bias_init=Tensor(fc1_bias))
self.fc2 = nn.Dense(in_channels=4, out_channels=1, weight_init=Tensor(fc2_weight), bias_init=Tensor(fc2_bias))
self.relu = nn.ReLU()
self.reducemean = P.ReduceMean()
def construct(self, x):
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
def _initialize_weights(self):
"""
parameter initialization
"""
self.init_parameters_data()
for name, m in self.cells_and_names():
if name == 'fc1':
m.weight.set_data(Tensor(fc1_weight))
m.bias.set_data(Tensor(fc1_bias))
elif name == 'fc2':
m.weight.set_data(Tensor(fc2_weight))
m.bias.set_data(Tensor(fc2_bias))
def build_network(opt_config, is_group=False):
"""
Construct training
"""
losses = []
net = FakeNet()
networkwithloss = NetWithLoss(net)
networkwithloss.set_train()
if is_group:
fc1_params = list(filter(lambda x: 'fc1' in x.name, networkwithloss.trainable_params()))
fc2_params = list(filter(lambda x: 'fc1' not in x.name, networkwithloss.trainable_params()))
if opt_config['name'] == 'ASGD':
params = [{'params': fc1_params, 'weight_decay': 0.01, 'lr': 0.001}, {'params': fc2_params, 'lr': 0.1}]
else:
params = [{'params': fc1_params, 'lr': 0.001}, {'params': fc2_params, 'lr': 0.1}]
else:
params = networkwithloss.trainable_params()
if opt_config['name'] == 'ASGD':
net_opt = ASGD(params, learning_rate=opt_config['lr'], lambd=opt_config['lambd'], alpha=opt_config['alpha'],
t0=opt_config['t0'], weight_decay=opt_config['weight_decay'])
elif opt_config['name'] == 'Rprop':
net_opt = Rprop(params, learning_rate=opt_config['lr'], etas=opt_config['etas'],
step_sizes=opt_config['step_sizes'], weight_decay=0.0)
trainonestepcell = mindspore.nn.TrainOneStepCell(networkwithloss, net_opt)
data, label = make_fake_data()
for i in range(20):
loss = trainonestepcell(data[i], label[i])
losses.append(loss.asnumpy())
if opt_config['name'] == 'ASGD':
return np.array(losses), net_opt
return np.array(losses)
loss_default_asgd = np.array([3.01246792e-01, 1.20041794e+02, 1.38681079e+03, 2.01250820e+01,
3.27283554e+01, 4.76963005e+01, 6.47094269e+01, 8.34786530e+01,
1.03742706e+02, 1.25265739e+02, 1.47835190e+02, 1.71259613e+02,
1.95367035e+02, 2.20003204e+02, 2.45029831e+02, 2.70323456e+02,
2.95774048e+02, 3.21283752e+02, 3.46765594e+02, 3.72143097e+02], dtype=np.float32)
loss_not_default_asgd = np.array([3.01246792e-01, 1.26019104e+02, 1.90600449e+02, 9.70605755e+00,
2.98419113e+01, 3.68430023e+02, 1.06318066e+04, 1.35017746e+02,
1.68673813e+02, 2.05914215e+02, 2.46694992e+02, 2.90972443e+02,
3.38703430e+02, 3.89845123e+02, 4.44355103e+02, 5.02191406e+02,
5.63312500e+02, 6.27676941e+02, 6.95244202e+02, 7.65973816e+02], dtype=np.float32)
loss_group_asgd = np.array([3.01246792e-01, 7.26708527e+01, 2.84905312e+05, 4.17499258e+04,
1.46797949e+04, 5.07966602e+03, 1.70935132e+03, 5.47094910e+02,
1.59216995e+02, 3.78818207e+01, 5.18196869e+00, 2.62275129e-03,
2.09768796e+00, 5.23108435e+00, 7.78943682e+00, 9.57108879e+00,
1.07310610e+01, 1.14618425e+01, 1.19147835e+01, 1.21936722e+01], dtype=np.float32)
loss_default_rprop = np.array([3.01246792e-01, 1.19871742e+02, 4.13467163e+02, 8.09146179e+02,
1.22364807e+03, 1.56787573e+03, 1.75733594e+03, 1.72866272e+03,
1.46183936e+03, 1.00406335e+03, 4.84076874e+02, 9.49734650e+01,
2.00592804e+01, 1.87920704e+01, 1.53733969e+01, 1.85836582e+01,
5.21527790e-02, 2.01522671e-02, 7.19913816e+00, 8.52459526e+00], dtype=np.float32)
loss_not_default_rprop = np.array([3.0124679e-01, 1.2600269e+02, 4.7351608e+02, 1.0220379e+03,
1.7181555e+03, 2.4367019e+03, 2.9170872e+03, 2.7243464e+03,
1.4999669e+03, 7.5820435e+01, 1.0590715e+03, 5.4336096e+02,
7.0162407e+01, 8.2754419e+02, 9.6329260e+02, 3.4475109e+01,
5.3843134e+02, 6.0064526e+02, 1.1046149e+02, 3.5530117e+03], dtype=np.float32)
loss_group_rprop = np.array([3.0124679e-01, 7.1360558e+01, 4.8910957e+01, 2.1730331e+02,
3.0747052e+02, 5.2734237e+00, 5.6865869e+00, 1.7116127e+02,
2.0539343e+02, 2.2993685e+01, 2.6194101e+02, 2.8772815e+02,
2.4236647e+01, 3.9299741e+02, 3.5600668e+02, 1.4759110e+01,
7.2244568e+02, 8.1952783e+02, 9.8913864e+01, 1.1141744e+03], dtype=np.float32)
default_fc1_weight_asgd = np.array([[-0.9451941, -0.71258026, -1.2602371, -1.4823773,
-0.974408, -1.2709816, -1.4194703, -1.2137808],
[-1.5341775, -2.0636342, -1.4916497, -1.3753126,
-1.9375193, -1.308271, -1.6262367, -1.9794592],
[-1.9886293, -2.0906024, -1.8060291, -1.5117803,
-1.6760755, -2.2942104, -1.7208353, -1.5884445],
[-2.071215, -2.2000103, -1.9404325, -1.7647781,
-1.4022746, -1.6987679, -2.0481179, -1.5297506]], dtype=np.float32)
default_fc1_bias_asgd = np.array([-0.17978168, -1.0764512, -0.578816, -0.2928958], dtype=np.float32)
default_fc2_weight_asgd = np.array([[4.097412, 6.2694297, 5.9203916, 5.3845487]], dtype=np.float32)
default_fc2_bias_asgd = np.array([6.904814], dtype=np.float32)
no_default_fc1_weight_asgd = np.array([[-1.3406217, -1.1080127, -1.655658, -1.8777936,
-1.3698348, -1.6664025, -1.8148884, -1.6092018],
[-1.1475986, -1.6770473, -1.1050745, -0.98873824,
-1.5509329, -0.9216978, -1.2396574, -1.5928726],
[-1.2329121, -1.334883, -1.050313, -0.756071,
-0.92036265, -1.5384867, -0.96512324, -0.8327349],
[-1.0685704, -1.1973612, -0.9377885, -0.7621386,
-0.39964262, -0.69612867, -1.0454736, -0.52711576]], dtype=np.float32)
no_default_fc1_bias_asgd = np.array([0.41264832, -0.19961096, 0.37743938, 0.65807366], dtype=np.float32)
no_default_fc2_weight_asgd = np.array([[-5.660916, -5.9415145, -5.1402636, -4.199707]], dtype=np.float32)
no_default_fc2_bias_asgd = np.array([0.5082278], dtype=np.float32)
no_default_group_fc1_weight_asgd = np.array([[-32.526627, -32.29401, -32.8416, -33.06367, -32.55584,
-32.852345, -33.000767, -32.795143],
[-33.164936, -33.69432, -33.12241, -33.006073, -33.568207,
-32.9391, -33.256996, -33.61015],
[-33.118973, -33.220943, -32.936436, -32.642193, -32.806488,
-33.424484, -32.85125, -32.718857],
[-30.155754, -30.284513, -30.025005, -29.849358, -29.486917,
-29.783375, -30.132658, -29.614393]], dtype=np.float32)
no_default_group_fc1_bias_asgd = np.array([-15.838092, -16.811989, -16.078112, -14.289094], dtype=np.float32)
no_default_group_fc2_weight_asgd = np.array([[1288.7146, 1399.3041, 1292.8445, 1121.4629]], dtype=np.float32)
no_default_group_fc2_bias_asgd = np.array([18.513494], dtype=np.float32)

View File

@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
from .optimizer_utils import build_network, loss_not_default_asgd, loss_default_asgd, loss_group_asgd
def test_default_asgd_pynative():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Pynative mode with default parameter
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import default_fc1_weight_asgd, \
default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd
context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0}
loss, cells = build_network(config)
assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_default_asgd, loss, atol=1.e-5)
def test_default_asgd_graph():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Graph mode with default parameter
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import default_fc1_weight_asgd, \
default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0}
loss, cells = build_network(config)
assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_default_asgd, loss, atol=1.e-5)
def test_no_default_asgd_pynative():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Pynative mode with another set of parameter
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import no_default_fc1_weight_asgd, \
no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd
context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
loss, cells = build_network(config)
assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3)
def test_no_default_asgd_graph():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Graph mode with another set of parameter
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import no_default_fc1_weight_asgd, \
no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
loss, cells = build_network(config)
assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3)
def test_default_asgd_group_pynative():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Pynative mode with parameter grouping
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \
no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd
context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
loss, cells = build_network(config, is_group=True)
assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3)
def test_default_asgd_group_graph():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Graph mode with parameter grouping
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \
no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
loss, cells = build_network(config, is_group=True)
assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3)

View File

@ -0,0 +1,72 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
from .optimizer_utils import build_network, loss_not_default_asgd, loss_default_asgd, loss_group_asgd
def test_default_asgd_graph():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Graph mode with default parameter
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import default_fc1_weight_asgd, \
default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0}
loss, cells = build_network(config)
assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_default_asgd, loss, atol=1.e-5)
def test_no_default_asgd_graph():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Graph mode with another set of parameter
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import no_default_fc1_weight_asgd, \
no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
loss, cells = build_network(config, is_group=True)
assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3)
def test_default_asgd_group_graph():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Graph mode with parameter grouping
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \
no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
loss, cells = build_network(config, is_group=True)
assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3)

View File

@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
from .optimizer_utils import build_network, loss_not_default_asgd, loss_default_asgd, loss_group_asgd
def test_default_asgd_pynative():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Pynative mode with default parameter
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import default_fc1_weight_asgd, \
default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0}
loss, cells = build_network(config)
assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_default_asgd, loss, atol=1.e-5)
def test_default_asgd_graph():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Graph mode with default parameter
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import default_fc1_weight_asgd, \
default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0}
loss, cells = build_network(config)
assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_default_asgd, loss, atol=1.e-5)
def test_no_default_asgd_pynative():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Pynative mode with another set of parameter
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import no_default_fc1_weight_asgd, \
no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
loss, cells = build_network(config)
assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3)
def test_no_default_asgd_graph():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Graph mode with another set of parameter
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import no_default_fc1_weight_asgd, \
no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
loss, cells = build_network(config)
assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3)
def test_default_asgd_group_pynative():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Pynative mode with parameter grouping
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \
no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
loss, cells = build_network(config, is_group=True)
assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3)
def test_default_asgd_group_graph():
"""
Feature: Test ASGD optimizer
Description: Test ASGD in Graph mode with parameter grouping
Expectation: Loss values and parameters conform to preset values.
"""
from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \
no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
loss, cells = build_network(config, is_group=True)
assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5)
assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5)
assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5)
assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3)

View File

@ -0,0 +1,90 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
from .optimizer_utils import build_network, loss_default_rprop, loss_group_rprop, loss_not_default_rprop
def test_default_rprop_pynative():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Pynative mode with default parameter
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0}
loss = build_network(config)
assert np.allclose(loss_default_rprop, loss, atol=1.e-5)
def test_default_rprop_graph():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Graph mode with default parameter
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0}
loss = build_network(config)
assert np.allclose(loss_default_rprop, loss, atol=1.e-5)
def test_no_default_rprop_pynative():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Pynative mode with another set of parameter
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0}
loss = build_network(config)
assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5)
def test_no_default_rprop_graph():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Graph mode with another set of parameter
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0}
loss = build_network(config)
assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5)
def test_default_rprop_group_pynative():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Pynative mode with parameter grouping
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0}
loss = build_network(config, is_group=True)
assert np.allclose(loss_group_rprop, loss, atol=1.e-5)
def test_default_rprop_group_graph():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Graph mode with parameter grouping
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0}
loss = build_network(config, is_group=True)
assert np.allclose(loss_group_rprop, loss, atol=1.e-5)

View File

@ -0,0 +1,54 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
from .optimizer_utils import build_network, loss_default_rprop, loss_group_rprop, loss_not_default_rprop
def test_default_rprop_graph():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Graph mode with default parameter
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0}
loss = build_network(config)
assert np.allclose(loss_default_rprop, loss, atol=1.e-5)
def test_no_default_rprop_graph():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Graph mode with another set of parameter
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0}
loss = build_network(config)
assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5)
def test_default_rprop_group_graph():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Graph mode with parameter grouping
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0}
loss = build_network(config, is_group=True)
assert np.allclose(loss_group_rprop, loss, atol=1.e-5)

View File

@ -0,0 +1,90 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
from .optimizer_utils import build_network, loss_default_rprop, loss_group_rprop, loss_not_default_rprop
def test_default_rprop_pynative():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Pynative mode with default parameter
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0}
loss = build_network(config)
assert np.allclose(loss_default_rprop, loss, atol=1.e-5)
def test_default_rprop_graph():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Graph mode with default parameter
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0}
loss = build_network(config)
assert np.allclose(loss_default_rprop, loss, atol=1.e-5)
def test_no_default_rprop_pynative():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Pynative mode with another set of parameter
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0}
loss = build_network(config)
assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5)
def test_no_default_rprop_graph():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Graph mode with another set of parameter
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0}
loss = build_network(config)
assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5)
def test_default_rprop_group_pynative():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Pynative mode with parameter grouping
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0}
loss = build_network(config, is_group=True)
assert np.allclose(loss_group_rprop, loss, atol=1.e-5)
def test_default_rprop_group_graph():
"""
Feature: Test Rprop optimizer
Description: Test Rprop in Graph mode with parameter grouping
Expectation: Loss values and parameters conform to preset values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0}
loss = build_network(config, is_group=True)
assert np.allclose(loss_group_rprop, loss, atol=1.e-5)