!34093 add nn.adamax optimizer
Merge pull request !34093 from liutongtong9/add_adamax
Commit: 8df3aa1c7e
@@ -0,0 +1,60 @@
.. py:class:: mindspore.nn.AdaMax(*args, **kwargs)

    The AdaMax algorithm is a variant of Adam based on the infinity norm.

    For details of the AdaMax algorithm, please refer to the paper `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows:

    .. math::
        \begin{array}{ll} \\
            m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
            v_{t+1} = \max(\beta_2 * v_{t}, \left| g \right|) \\
            w = w - \frac{l}{1 - \beta_1^{t+1}} * \frac{m_{t+1}}{v_{t+1} + \epsilon}
        \end{array}

    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents the gradients `gradients`, :math:`\beta_1, \beta_2` represent the decay rates `beta1` and `beta2`, :math:`t` represents the current step, :math:`\beta_1^t` represents `beta1` to the power of t, :math:`l` represents the learning rate `learning_rate`, :math:`w` represents `params`, and :math:`\epsilon` represents `eps`.

    .. note::
        .. include:: mindspore.nn.optim_note_weight_decay.rst

    **Parameters:**

    - **params** (Union[list[Parameter], list[dict]]) - Must be a list of `Parameter` or a list of `dict`. When the list elements are dicts, the keys "params", "lr", "weight_decay", "grad_centralization" and "order_params" can be parsed:

      .. include:: mindspore.nn.optim_group_param.rst
      .. include:: mindspore.nn.optim_group_lr.rst
      .. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst
      .. include:: mindspore.nn.optim_group_gc.rst
      .. include:: mindspore.nn.optim_group_order.rst

    - **learning_rate** (Union[float, Tensor, iterable, LearningRateSchedule]) - Default: 0.001.

      .. include:: mindspore.nn.optim_arg_dynamic_lr.rst

    - **beta1** (float) - The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0). Default: 0.9.
    - **beta2** (float) - The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). Default: 0.999.
    - **eps** (float) - Term added to the denominator to improve numerical stability. Must be greater than 0. Default: 1e-8.
    - **weight_decay** (Union[float, int, Cell]) - Weight decay (L2 penalty). Default: 0.0.

      .. include:: mindspore.nn.optim_arg_dynamic_wd.rst

    .. include:: mindspore.nn.optim_arg_loss_scale.rst

    **Inputs:**

    - **gradients** (tuple[Tensor]) - The gradients of `params`, with the same shape as `params`.

    **Outputs:**

    Tensor[bool], the value is True.

    **Raises:**

    - **TypeError** - `learning_rate` is not an int, float, Tensor, iterable or LearningRateSchedule.
    - **TypeError** - An element of `parameters` is neither a Parameter nor a dict.
    - **TypeError** - `beta1`, `beta2`, `eps` or `loss_scale` is not a float.
    - **TypeError** - `weight_decay` is neither a float nor an int.
    - **ValueError** - `loss_scale` or `eps` is less than or equal to 0.
    - **ValueError** - `beta1` or `beta2` is not in range (0.0, 1.0).
    - **ValueError** - `weight_decay` is less than 0.
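The update rule above can be sanity-checked outside MindSpore with a short NumPy sketch (illustrative only: `adamax_step` is a hypothetical helper that mirrors the formulas; the implementation added by this PR calls the fused `ApplyAdaMax` operator instead):

import numpy as np

def adamax_step(w, m, v, g, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One element-wise AdaMax update following the formulas above."""
    m = beta1 * m + (1 - beta1) * g                      # 1st moment
    v = np.maximum(beta2 * v, np.abs(g))                 # infinity-norm 2nd moment
    w = w - lr / (1 - beta1 ** (t + 1)) * m / (v + eps)  # bias-corrected step
    return w, m, v

# Toy usage: minimize f(w) = 0.5 * ||w||^2, whose gradient is w itself.
w = np.array([1.0, -2.0, 3.0])
m = np.zeros_like(w)
v = np.zeros_like(w)
for t in range(200):
    w, m, v = adamax_step(w, m, v, g=w, t=t, lr=0.05)
print(w)  # w has been driven toward the minimum at the origin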
@@ -34,7 +34,8 @@ from .ada_grad import Adagrad
 from .thor import thor
 from .adafactor import AdaFactor
 from .adasum import AdaSumByDeltaWeightWrapCell, AdaSumByGradWrapCell
+from .adamax import AdaMax
 
 __all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam', 'AdamOffload',
            'Lamb', 'SGD', 'ASGD', 'Rprop', 'FTRL', 'RMSProp', 'ProximalAdagrad', 'Adagrad', 'thor', 'AdaFactor',
-           'AdaSumByDeltaWeightWrapCell', 'AdaSumByGradWrapCell']
+           'AdaSumByDeltaWeightWrapCell', 'AdaSumByGradWrapCell', 'AdaMax']
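A quick way to confirm the new export path once a build containing this change is installed (a minimal check; not part of the PR itself):

from mindspore import nn
from mindspore.nn.optim import AdaMax

# Both import paths should resolve to the same class after the __all__ update.
assert nn.AdaMax is AdaMax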
@@ -0,0 +1,211 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""adamax"""
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
from mindspore._checkparam import Rel

_ada_max_opt = C.MultitypeFuncGraph("ada_max_opt")


@_ada_max_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                       "Tensor", "Tensor")
def _tensor_run_opt(opt, beta1, beta2, beta1_power, eps, learning_rate, weight, moment1, moment2, gradient):
    """Apply one ApplyAdaMax step to a single parameter."""
    success = True
    success = F.depend(success, opt(weight, moment1, moment2, beta1_power, learning_rate, beta1, beta2, eps, gradient))
    return success


def _check_param_value(beta1, beta2, eps, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)


class AdaMax(Optimizer):
    r"""
    Implements the AdaMax algorithm, a variant of Adaptive Movement Estimation (Adam) based on the infinity norm.

    The AdaMax algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows:

    .. math::
        \begin{array}{ll} \\
            m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
            v_{t+1} = \max(\beta_2 * v_{t}, \left| g \right|) \\
            w = w - \frac{l}{1 - \beta_1^{t+1}} * \frac{m_{t+1}}{v_{t+1} + \epsilon}
        \end{array}

    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector,
    :math:`g` represents `gradients`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
    :math:`t` represents the current step, :math:`\beta_1^t` represents `beta1_power`,
    :math:`l` represents `learning_rate`, :math:`w` represents `params`,
    :math:`\epsilon` represents `eps`.

    Note:
        If parameters are not grouped, the `weight_decay` in the optimizer will be applied on the network parameters
        without 'beta' or 'gamma' in their names. Users can group parameters to change the strategy of decaying
        weight. When parameters are grouped, each group can set `weight_decay`; if not, the `weight_decay` in the
        optimizer will be applied.

    Args:
        params (Union[list[Parameter], list[dict]]): Must be a list of `Parameter` or a list of `dict`. When
            `params` is a list of `dict`, the strings "params", "lr", "weight_decay", "grad_centralization" and
            "order_params" are the keys that can be parsed.

            - params: Required. Parameters in the current group. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the optimizer will be used. Fixed and dynamic learning rates are
              supported.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the optimizer will be used. Note that weight decay can be
              a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic weight
              decay is similar to dynamic learning rate: users customize a weight decay schedule with only the
              global step as input, and during training the optimizer calls the instance of WeightDecaySchedule
              to get the weight decay value of the current step.

            - grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
              will be used. If not, `grad_centralization` is False by default. This configuration only works on the
              convolution layer.

            - order_params: Optional. When parameters are grouped, this is usually used to maintain the order of
              parameters that appear in the network to improve performance. The value should be parameters whose
              order will be followed in the optimizer.
              If `order_params` is in the keys, other keys will be ignored and the elements of 'order_params' must
              be in one group of `params`.

        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: 0.001.

            - float: The fixed learning rate value. Must be equal to or greater than 0.

            - int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float.

            - Tensor: Its value should be a scalar or a 1-D vector. For a scalar, a fixed learning rate will be
              applied. For a vector, the learning rate is dynamic and the i-th step will take the i-th value as the
              learning rate.

            - Iterable: The learning rate is dynamic. The i-th step will take the i-th value as the learning rate.

            - LearningRateSchedule: The learning rate is dynamic. During training, the optimizer calls the instance
              of LearningRateSchedule with the step as input to get the learning rate of the current step.

        beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
            Default: 0.9.
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
            Default: 0.999.
        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
            Default: 1e-8.
        weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.

            - float: The fixed weight decay value. Must be equal to or greater than 0.

            - int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.

            - Cell: Weight decay is dynamic. During training, the optimizer calls the instance of the Cell with the
              step as input to get the weight decay value of the current step.

        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
            default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to False, this value needs to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
            Default: 1.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, with the same shape as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If an element of `parameters` is neither a Parameter nor a dict.
        TypeError: If `beta1`, `beta2`, `eps` or `loss_scale` is not a float.
        TypeError: If `weight_decay` is neither a float nor an int.
        ValueError: If `loss_scale` or `eps` is less than or equal to 0.
        ValueError: If `beta1` or `beta2` is not in range (0.0, 1.0).
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindspore import nn, Model
        >>>
        >>> net = Net()
        >>> # 1) All parameters use the same learning rate and weight decay
        >>> optim = nn.AdaMax(params=net.trainable_params())
        >>>
        >>> # 2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.AdaMax(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use the default learning rate of 0.1, a weight decay of 0.01 and
        >>> # grad centralization of True.
        >>> # The no_conv_params's parameters will use a learning rate of 0.01, the default weight decay of 0.0 and
        >>> # grad centralization of False.
        >>> # The final parameter order the optimizer follows is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    @opt_init_args_register
    def __init__(self, params, learning_rate=0.001, beta1=0.9, beta2=0.999, eps=1e-08,
                 weight_decay=0.0, loss_scale=1.0):
        super(AdaMax, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param_value(beta1, beta2, eps, self.cls_name)

        self.beta1 = Tensor(beta1, mstype.float32)
        self.beta2 = Tensor(beta2, mstype.float32)
        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
        self.eps = Tensor(eps, mstype.float32)
        self.moment1 = self._parameters.clone(prefix="moment1", init='zeros')
        self.moment2 = self._parameters.clone(prefix="moment2", init='zeros')

        self.opt = P.ApplyAdaMax()

    def construct(self, gradients):
        gradients = self.flatten_gradients(gradients)
        gradients = self.decay_weight(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()

        self.beta1_power *= self.beta1

        if self.is_group_lr:
            success = self.map_(F.partial(_ada_max_opt, self.opt, self.beta1, self.beta2, self.beta1_power, self.eps),
                                lr, self._parameters, self.moment1, self.moment2, gradients)
        else:
            success = self.map_(F.partial(_ada_max_opt, self.opt, self.beta1, self.beta2, self.beta1_power,
                                          self.eps, lr), self._parameters, self.moment1, self.moment2, gradients)

        return success
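For reference, `construct` consumes a tuple of gradients whose shapes match the parameters and returns True. A minimal sketch of that calling contract, assuming a backend that provides `ApplyAdaMax` (the docstring lists Ascend) and using a placeholder `nn.Dense` layer with synthetic gradients:

import numpy as np
from mindspore import context, nn, Tensor

context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')

net = nn.Dense(8, 4)                                   # placeholder layer
opt = nn.AdaMax(net.trainable_params(), learning_rate=0.001)

# Synthetic gradients, one per parameter, same shape as each parameter.
grads = tuple(Tensor(np.ones(p.shape, np.float32)) for p in opt.parameters)
success = opt(grads)                                   # one AdaMax step; returns True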
@@ -19,6 +19,7 @@ from mindspore import nn, Tensor
 from mindspore.ops import operations as P
 from mindspore.nn.optim import ASGD
 from mindspore.nn.optim import Rprop
+from mindspore.nn.optim import AdaMax
 np.random.seed(1024)
 
 fc1_weight = np.array([[0.72346634, 0.95608497, 0.4084163, 0.18627149,
@@ -52,10 +53,10 @@ class NetWithLoss(nn.Cell):
     """
     build net with loss
     """
-    def __init__(self, network):
+    def __init__(self, network, loss_fn):
         super(NetWithLoss, self).__init__()
         self.network = network
-        self.loss = nn.MSELoss(reduction='sum')
+        self.loss = loss_fn
 
     def construct(self, x, label):
         out = self.network(x)
@@ -93,14 +94,13 @@ class FakeNet(nn.Cell):
                 m.bias.set_data(Tensor(fc2_bias))
 
 
-def build_network(opt_config, is_group=False):
+def build_network(opt_config, is_group=False, net=FakeNet(), loss_fn=nn.MSELoss(reduction='sum')):
     """
     Construct training
     """
     losses = []
-    net = FakeNet()
 
-    networkwithloss = NetWithLoss(net)
+    networkwithloss = NetWithLoss(net, loss_fn)
     networkwithloss.set_train()
 
     if is_group:
@@ -108,6 +108,8 @@ def build_network(opt_config, is_group=False):
         fc2_params = list(filter(lambda x: 'fc1' not in x.name, networkwithloss.trainable_params()))
         if opt_config['name'] == 'ASGD':
             params = [{'params': fc1_params, 'weight_decay': 0.01, 'lr': 0.001}, {'params': fc2_params, 'lr': 0.1}]
+        elif opt_config['name'] == 'adamax':
+            params = [{'params': fc1_params, 'lr': 0.0018}, {'params': fc2_params, 'lr': 0.0022}]
         else:
             params = [{'params': fc1_params, 'lr': 0.001}, {'params': fc2_params, 'lr': 0.1}]
     else:
@@ -121,6 +123,10 @@ def build_network(opt_config, is_group=False):
         net_opt = Rprop(params, learning_rate=opt_config['lr'], etas=opt_config['etas'],
                         step_sizes=opt_config['step_sizes'], weight_decay=0.0)
 
+    elif opt_config['name'] == 'adamax':
+        net_opt = AdaMax(params, learning_rate=opt_config['lr'], beta1=opt_config['beta1'],
+                         beta2=opt_config['beta2'], eps=opt_config['eps'], weight_decay=0.0)
+
     trainonestepcell = mindspore.nn.TrainOneStepCell(networkwithloss, net_opt)
     data, label = make_fake_data()
     for i in range(20):
@@ -168,6 +174,21 @@ loss_group_rprop = np.array([3.0124679e-01, 7.1360558e+01, 4.8910957e+01, 2.1730
                              2.4236647e+01, 3.9299741e+02, 3.5600668e+02, 1.4759110e+01,
                              7.2244568e+02, 8.1952783e+02, 9.8913864e+01, 1.1141744e+03], dtype=np.float32)
 
+loss_default_adamax = np.array([1.0, 4.542382, 10.5303135, 18.87176, 29.475002,
+                                42.2471, 57.09358, 73.917595, 92.62038, 113.10096,
+                                135.25633, 158.9815, 184.16951, 210.71207, 238.49873,
+                                267.41818, 297.35782, 328.20422, 359.84293, 392.15878], dtype=np.float32)
+
+loss_not_default_adamax = np.array([1.0, 4.5040994, 9.420462, 14.951918, 20.390736,
+                                    25.111732, 28.57695, 30.347034, 30.098299, 27.647425,
+                                    22.994541, 16.402872, 8.979612, 2.7966619, 0.025522191,
+                                    1.9826386, 8.12521, 15.100327, 18.94126, 19.657328], dtype=np.float32)
+
+loss_group_adamax = np.array([1.0, 4.537268, 10.415594, 18.463926, 28.51337,
+                              40.394474, 53.936195, 68.9657, 85.307945, 102.78646,
+                              121.22308, 140.4386, 160.25333, 180.48737, 200.96124,
+                              221.49626, 241.91531, 262.0436, 281.70914, 300.7426], dtype=np.float32)
+
 
 default_fc1_weight_asgd = np.array([[-0.9451941, -0.71258026, -1.2602371, -1.4823773,
                                      -0.974408, -1.2709816, -1.4194703, -1.2137808],
@@ -0,0 +1,129 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import mindspore.context as context
from mindspore import nn, Tensor
from optimizer_utils import build_network, \
    loss_default_adamax, loss_not_default_adamax, loss_group_adamax


w1 = np.array([[0.03909272, 0.08893055, -0.259909, -0.459185,
                -0.0195536, 0.12977135, -0.62942827, -0.53132117],
               [0.1542052, 0.6513571, -0.06453168, 0.44788414,
                -0.3775454, 0.6520292, 0.444174, -0.59306043],
               [0.2712369, 0.20890862, 0.6859066, 0.6629662,
                0.4724893, -0.34384444, -0.16007674, 0.21797538],
               [-0.3865972, 0.26727962, 0.23178828, -0.24629539,
                -0.68038213, -0.31262863, 0.10493469, -0.28973007]]).astype("float32")

b1 = np.array([0., 0., 0., 0.]).astype("float32")

w2 = np.array([[-0.6079024, -1.005364, 0.59004724, 0.7289244]]).astype("float32")

b2 = np.array([0.]).astype("float32")


class Net(nn.Cell):
    """
    Build a two-layer net to test the AdaMax optimizer.
    """
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Dense(8, 4, weight_init=Tensor(w1), bias_init=Tensor(b1))
        self.fc2 = nn.Dense(4, 1, weight_init=Tensor(w2), bias_init=Tensor(b2))
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.relu(self.fc1(x))
        return self.fc2(x)


def test_default_adamax_pynative():
    """
    Feature: Test adamax optimizer
    Description: Test adamax in PyNative mode with default parameters
    Expectation: Loss values and parameters conform to preset values.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
    config = {'name': 'adamax', 'lr': 0.001, "beta1": 0.9, "beta2": 0.999, "eps": 1e-07,
              'weight_decay': 0.0}
    loss = build_network(config, net=Net(), loss_fn=nn.MSELoss(reduction='mean'))
    assert np.allclose(loss_default_adamax, loss, atol=1.e-5)


def test_default_adamax_graph():
    """
    Feature: Test adamax optimizer
    Description: Test adamax in Graph mode with default parameters
    Expectation: Loss values and parameters conform to preset values.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
    config = {'name': 'adamax', 'lr': 0.001, "beta1": 0.9, "beta2": 0.999, "eps": 1e-07,
              'weight_decay': 0.0}
    loss = build_network(config, net=Net(), loss_fn=nn.MSELoss(reduction='mean'))
    assert np.allclose(loss_default_adamax, loss, atol=1.e-5)


def test_no_default_adamax_pynative():
    """
    Feature: Test adamax optimizer
    Description: Test adamax in PyNative mode with another set of parameters
    Expectation: Loss values and parameters conform to preset values.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
    config = {'name': 'adamax', 'lr': 0.01, "beta1": 0.9, "beta2": 0.98, "eps": 1e-06,
              'weight_decay': 0.0}
    loss = build_network(config, net=Net(), loss_fn=nn.MSELoss(reduction='mean'))
    assert np.allclose(loss_not_default_adamax, loss, atol=1.e-5)


def test_no_default_adamax_graph():
    """
    Feature: Test adamax optimizer
    Description: Test adamax in Graph mode with another set of parameters
    Expectation: Loss values and parameters conform to preset values.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
    config = {'name': 'adamax', 'lr': 0.01, "beta1": 0.9, "beta2": 0.98, "eps": 1e-06,
              'weight_decay': 0.0}
    loss = build_network(config, net=Net(), loss_fn=nn.MSELoss(reduction='mean'))
    assert np.allclose(loss_not_default_adamax, loss, atol=1.e-5)


def test_default_adamax_group_pynative():
    """
    Feature: Test adamax optimizer
    Description: Test adamax in PyNative mode with parameter grouping
    Expectation: Loss values and parameters conform to preset values.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
    config = {'name': 'adamax', 'lr': 0.002, "beta1": 0.9, "beta2": 0.999, "eps": 1e-08,
              'weight_decay': 0.0}
    loss = build_network(config, is_group=True, net=Net(), loss_fn=nn.MSELoss(reduction='mean'))
    assert np.allclose(loss_group_adamax, loss, atol=1.e-5)


def test_default_adamax_group_graph():
    """
    Feature: Test adamax optimizer
    Description: Test adamax in Graph mode with parameter grouping
    Expectation: Loss values and parameters conform to preset values.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
    config = {'name': 'adamax', 'lr': 0.002, "beta1": 0.9, "beta2": 0.999, "eps": 1e-08,
              'weight_decay': 0.0}
    loss = build_network(config, is_group=True, net=Net(), loss_fn=nn.MSELoss(reduction='mean'))
    assert np.allclose(loss_group_adamax, loss, atol=1.e-5)