forked from mindspore-Ecosystem/mindspore
!2007 add lazy adam optimizer and support sparse adam&ftrl for cpu backend
Merge pull request !2007 from wangnan39/add_lazy_adam_optim_and_support_sparse_admm_for_cpu_backend
This commit is contained in:
commit
3536185f5b
|
@ -27,6 +27,7 @@ from .lars import LARS
|
|||
from .ftrl import FTRL
|
||||
from .rmsprop import RMSProp
|
||||
from .proximal_ada_grad import ProximalAdagrad
|
||||
from .lazyadam import LazyAdam
|
||||
|
||||
__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay',
|
||||
__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam',
|
||||
'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad']
|
||||
|
|
|
@ -101,10 +101,21 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po
|
|||
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
|
||||
|
||||
|
||||
@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor")
|
||||
def _run_opt_with_one_number(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, moment1,
|
||||
moment2):
|
||||
@adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
|
||||
"Tensor", "Tensor", "Tensor")
|
||||
def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
|
||||
moment1, moment2):
|
||||
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
|
||||
success = True
|
||||
success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
|
||||
eps, gradient[1], gradient[0]))
|
||||
return success
|
||||
|
||||
|
||||
@adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor")
|
||||
def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
|
||||
moment1, moment2):
|
||||
"""Apply adam optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
|
||||
|
@ -144,6 +155,10 @@ class Adam(Optimizer):
|
|||
|
||||
To improve parameter groups performance, the customized order of parameters can be supported.
|
||||
|
||||
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
|
||||
`sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse
|
||||
behavior is currently performed on the CPU, weight decay and loss scale is not supported.
|
||||
|
||||
Args:
|
||||
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
|
||||
the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
|
||||
|
@ -232,12 +247,9 @@ class Adam(Optimizer):
|
|||
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
|
||||
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.map_ = C.Map()
|
||||
self.opt = P.Adam(use_locking, use_nesterov)
|
||||
|
||||
self.pow = P.Pow()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.one = Tensor(np.array([1.0]).astype(np.float32))
|
||||
self.realdiv = P.RealDiv()
|
||||
self.sparse_opt = P.SparseApplyAdam()
|
||||
|
||||
def construct(self, gradients):
|
||||
params = self.parameters
|
||||
|
@ -252,13 +264,13 @@ class Adam(Optimizer):
|
|||
beta2_power = self.beta2_power * self.beta2
|
||||
self.beta2_power = beta2_power
|
||||
if self.is_group_lr:
|
||||
success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
|
||||
self.beta2, self.eps),
|
||||
lr, gradients, params, moment1, moment2)
|
||||
success = self.map_(F.partial(adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
|
||||
self.beta1, self.beta2, self.eps),
|
||||
lr, gradients, params, moment1, moment2)
|
||||
else:
|
||||
success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
|
||||
self.beta2, self.eps, lr),
|
||||
gradients, params, moment1, moment2)
|
||||
success = self.map_(F.partial(adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
|
||||
self.beta1, self.beta2, self.eps, lr),
|
||||
gradients, params, moment1, moment2)
|
||||
return success
|
||||
|
||||
|
||||
|
|
|
@ -23,8 +23,18 @@ from .optimizer import Optimizer, apply_decay, grad_scale
|
|||
ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
|
||||
|
||||
|
||||
@ftrl_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
|
||||
def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
|
||||
@ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor",
|
||||
"Tensor")
|
||||
def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
|
||||
"""Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
|
||||
success = True
|
||||
success = F.depend(success, spars_opt(weight, moment, linear, gradient[1], gradient[0]))
|
||||
return success
|
||||
|
||||
|
||||
@ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor")
|
||||
def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
|
||||
"""Apply ftrl optimizer to the weight parameter."""
|
||||
success = True
|
||||
success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
|
||||
|
@ -67,6 +77,11 @@ class FTRL(Optimizer):
|
|||
<https://arxiv.org/abs/1002.4908>`_. Refer to paper `Ad Click Prediction: a View from the Trenches
|
||||
<https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf>`_ for engineering document.
|
||||
|
||||
Note:
|
||||
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
|
||||
`sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse
|
||||
behavior is currently performed on the CPU, weight decay and loss scale is not supported.
|
||||
|
||||
Args:
|
||||
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
|
||||
should be Parameter.
|
||||
|
@ -109,8 +124,9 @@ class FTRL(Optimizer):
|
|||
self.weight_decay = weight_decay
|
||||
self.decay_tf = tuple((lambda: True)() for x in self.parameters)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.map_ = C.Map()
|
||||
self.opt = P.ApplyFtrl(use_locking=use_locking)
|
||||
self.one = Tensor(1, mstype.int32)
|
||||
self.sparse_opt = P.SparseApplyFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
|
||||
|
||||
def construct(self, grads):
|
||||
params = self.parameters
|
||||
|
@ -121,6 +137,6 @@ class FTRL(Optimizer):
|
|||
if self.reciprocal_scale != 1.0:
|
||||
grads = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), grads)
|
||||
lr = self.learning_rate
|
||||
success = self.hyper_map(F.partial(ftrl_opt, self.opt, lr, self.l1, self.l2, self.lr_power),
|
||||
linear, grads, params, moments)
|
||||
success = self.map_(F.partial(ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
|
||||
linear, grads, params, moments)
|
||||
return success
|
||||
|
|
|
@ -0,0 +1,202 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lazy adam"""
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from .optimizer import Optimizer
|
||||
|
||||
lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
|
||||
|
||||
|
||||
@lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
|
||||
"Tensor", "Tensor", "Tensor")
|
||||
def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
|
||||
moment1, moment2):
|
||||
"""Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
|
||||
success = True
|
||||
success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
|
||||
eps, gradient[1], gradient[0]))
|
||||
return success
|
||||
|
||||
|
||||
@lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor")
|
||||
def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
|
||||
moment1, moment2):
|
||||
"""Apply adam optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
|
||||
eps, gradient))
|
||||
return success
|
||||
|
||||
|
||||
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
|
||||
"""Check the type of inputs."""
|
||||
validator.check_value_type("beta1", beta1, [float], prim_name)
|
||||
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||
validator.check_value_type("eps", eps, [float], prim_name)
|
||||
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
|
||||
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
|
||||
|
||||
|
||||
class LazyAdam(Optimizer):
|
||||
r"""
|
||||
Updates gradients by Adaptive Moment Estimation (Adam) algorithm.
|
||||
|
||||
The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
|
||||
|
||||
The updating formulas are as follows,
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
m = \beta_1 * m + (1 - \beta_1) * g \\
|
||||
v = \beta_2 * v + (1 - \beta_2) * g * g \\
|
||||
l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
|
||||
w = w - l * \frac{m}{\sqrt{v} + \epsilon}
|
||||
\end{array}
|
||||
|
||||
:math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
|
||||
:math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent
|
||||
`beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent
|
||||
`beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
|
||||
:math:`\epsilon` represents `eps`.
|
||||
|
||||
Note:
|
||||
The LazyAdam optimizer supports separating parameter groups. Different parameter groups can set different
|
||||
`learning_rate` and `weight_decay`.
|
||||
|
||||
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
|
||||
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
|
||||
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
|
||||
|
||||
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
|
||||
`sparse_grad` of `Parameter` being set as True. The sparse behavior, to be notice, is not equivalent to the
|
||||
original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under
|
||||
continuous development. The sparse behavior is currently performed on the CPU, weight decay and loss scale is
|
||||
not supported.
|
||||
|
||||
Args:
|
||||
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
|
||||
the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
|
||||
"lr" and "weight_decay" are the keys can be parsed.
|
||||
|
||||
- params: Required. The value should be a list of `Parameter`.
|
||||
|
||||
- lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
|
||||
If not, the `learning_rate` in the API will be used.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the API will be used.
|
||||
|
||||
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
|
||||
Iterable or a Tensor and the dims of the Tensor is 1,
|
||||
use dynamic learning rate, then the i-th step will
|
||||
take the i-th value as the learning rate.
|
||||
When the learning_rate is float or learning_rate is a Tensor
|
||||
but the dims of the Tensor is 0, use fixed learning rate.
|
||||
Other cases are not supported. Default: 1e-3.
|
||||
beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default:
|
||||
0.9.
|
||||
beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default:
|
||||
0.999.
|
||||
eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
|
||||
1e-8.
|
||||
use_locking (bool): Whether to enable a lock to protect updating variable tensors.
|
||||
If True, updating of the var, m, and v tensors will be protected by a lock.
|
||||
If False, the result is unpredictable. Default: False.
|
||||
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
|
||||
If True, updates the gradients using NAG.
|
||||
If False, updates the gradients without using NAG. Default: False.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
|
||||
loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
|
||||
1.0.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
||||
Outputs:
|
||||
Tensor[bool], the value is True.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> #1) All parameters use the same learning rate and weight decay
|
||||
>>> optim = nn.LazyAdam(params=net.trainable_params())
|
||||
>>>
|
||||
>>> #2) Use parameter groups and set different values
|
||||
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
|
||||
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
|
||||
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01},
|
||||
>>> {'params': no_conv_params}]
|
||||
>>> opt = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0)
|
||||
>>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01
|
||||
>>> # the no_cov_params's parameters don't set learning and weight decay. So they will use a
|
||||
>>> # learning rate of 0.1 and a weight decay of 0.0.
|
||||
>>>
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim)
|
||||
"""
|
||||
|
||||
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
|
||||
use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
|
||||
super(LazyAdam, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
|
||||
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
|
||||
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
|
||||
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
|
||||
validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
|
||||
self.beta1 = Tensor(beta1, mstype.float32)
|
||||
self.beta2 = Tensor(beta2, mstype.float32)
|
||||
self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
|
||||
self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
|
||||
self.eps = eps
|
||||
self.use_nesterov = use_nesterov
|
||||
self.use_locking = use_locking
|
||||
|
||||
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
|
||||
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
|
||||
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.map_ = C.Map()
|
||||
self.opt = P.Adam(use_locking, use_nesterov)
|
||||
self.sparse_opt = P.SparseApplyLazyAdam(use_locking, use_nesterov)
|
||||
|
||||
def construct(self, gradients):
|
||||
gradients = self.decay_weight(gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
|
||||
self.beta1_power = self.beta1_power * self.beta1
|
||||
self.beta2_power = self.beta2_power * self.beta2
|
||||
|
||||
if self.is_group_lr:
|
||||
success = self.map_(F.partial(lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
|
||||
self.beta2_power, self.beta1, self.beta2, self.eps),
|
||||
lr, gradients, self.parameters, self.moment1, self.moment2)
|
||||
else:
|
||||
success = self.map_(F.partial(lazy_adam_opt, self.opt, self.sparse_opt, self.beta1_power,
|
||||
self.beta2_power, self.beta1, self.beta2, self.eps, lr),
|
||||
gradients, self.parameters, self.moment1, self.moment2)
|
||||
return success
|
|
@ -21,7 +21,7 @@ from mindspore import Tensor, Parameter
|
|||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import AdamWeightDecay, AdamWeightDecayDynamicLR
|
||||
from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
|
@ -50,6 +50,19 @@ class NetWithoutWeight(nn.Cell):
|
|||
return x
|
||||
|
||||
|
||||
class NetWithSparseGatherV2(nn.Cell):
|
||||
""" NetWithSparseGatherV2 definition """
|
||||
def __init__(self):
|
||||
super(NetWithSparseGatherV2, self).__init__()
|
||||
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1", sparse_grad=True)
|
||||
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
|
||||
self.axis = 0
|
||||
self.gather = P.SparseGatherV2()
|
||||
|
||||
def construct(self, indices, label):
|
||||
return self.gather(self.weight1, indices, self.axis) + self.weight2
|
||||
|
||||
|
||||
def test_adamwithoutparam():
|
||||
net = NetWithoutWeight()
|
||||
net.set_train()
|
||||
|
@ -72,6 +85,33 @@ def test_adamw_compile():
|
|||
_executor.compile(train_network, inputs, label)
|
||||
|
||||
|
||||
def test_adam_compile():
|
||||
""" test adam compile """
|
||||
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
|
||||
label = Tensor(np.zeros([1, 10]).astype(np.float32))
|
||||
net = Net()
|
||||
net.set_train()
|
||||
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
optimizer = Adam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9)
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
_executor.compile(train_network, inputs, label)
|
||||
|
||||
|
||||
def test_spares_adam_compile():
|
||||
""" test_sparse_adam_compile """
|
||||
indices = Tensor(np.array([0, 1]).astype(np.int32))
|
||||
label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
|
||||
net = NetWithSparseGatherV2()
|
||||
net.set_train()
|
||||
|
||||
optimizer = Adam(net.trainable_params(), learning_rate=0.1)
|
||||
train_network = TrainOneStepCell(net, optimizer)
|
||||
_executor.compile(train_network, indices, label)
|
||||
|
||||
|
||||
def test_AdamWeightDecay_beta1():
|
||||
net = Net()
|
||||
print("**********", net.get_parameters())
|
||||
|
|
|
@ -37,6 +37,19 @@ class Net(nn.Cell):
|
|||
return x
|
||||
|
||||
|
||||
class NetWithSparseGatherV2(nn.Cell):
|
||||
""" NetWithSparseGatherV2 definition """
|
||||
def __init__(self):
|
||||
super(NetWithSparseGatherV2, self).__init__()
|
||||
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1", sparse_grad=True)
|
||||
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
|
||||
self.axis = 0
|
||||
self.gather = P.SparseGatherV2()
|
||||
|
||||
def construct(self, indices, label):
|
||||
return self.gather(self.weight1, indices, self.axis) + self.weight2
|
||||
|
||||
|
||||
def test_ftrl():
|
||||
""" test_ftrl """
|
||||
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
|
||||
|
@ -48,3 +61,15 @@ def test_ftrl():
|
|||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
_executor.compile(train_network, inputs, label)
|
||||
|
||||
|
||||
def test_spares_ftrl_compile():
|
||||
""" test sparse ftrl compile """
|
||||
indices = Tensor(np.array([0, 1]).astype(np.int32))
|
||||
label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
|
||||
net = NetWithSparseGatherV2()
|
||||
net.set_train()
|
||||
|
||||
optimizer = FTRL(net.trainable_params())
|
||||
train_network = TrainOneStepCell(net, optimizer)
|
||||
_executor.compile(train_network, indices, label)
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test lazy adam """
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import LazyAdam
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
""" Net definition """
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
|
||||
self.bias = Parameter(Tensor(np.ones([10]).astype((np.float32))), name="bias")
|
||||
self.matmul = P.MatMul()
|
||||
self.biasAdd = P.BiasAdd()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.biasAdd(self.matmul(x, self.weight), self.bias)
|
||||
return x
|
||||
|
||||
|
||||
class NetWithSparseGatherV2(nn.Cell):
|
||||
""" NetWithSparseGatherV2 definition """
|
||||
def __init__(self):
|
||||
super(NetWithSparseGatherV2, self).__init__()
|
||||
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1", sparse_grad=True)
|
||||
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
|
||||
self.axis = 0
|
||||
self.gather = P.SparseGatherV2()
|
||||
|
||||
def construct(self, indices, label):
|
||||
return self.gather(self.weight1, indices, self.axis) + self.weight2
|
||||
|
||||
|
||||
def test_lazy_adam_compile():
|
||||
""" test lazy adam compile """
|
||||
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
|
||||
label = Tensor(np.zeros([1, 10]).astype(np.float32))
|
||||
net = Net()
|
||||
net.set_train()
|
||||
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9)
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
_executor.compile(train_network, inputs, label)
|
||||
|
||||
|
||||
def test_spares_lazy_adam_compile():
|
||||
""" test sparse adam compile """
|
||||
indices = Tensor(np.array([0, 1]).astype(np.int32))
|
||||
label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
|
||||
net = NetWithSparseGatherV2()
|
||||
net.set_train()
|
||||
|
||||
optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1)
|
||||
train_network = TrainOneStepCell(net, optimizer)
|
||||
_executor.compile(train_network, indices, label)
|
||||
|
||||
|
||||
def test_lazy_adam_error():
|
||||
net = Net()
|
||||
with pytest.raises(ValueError):
|
||||
LazyAdam(net.get_parameters(), learning_rate=-0.1)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
LazyAdam(net.get_parameters(), learning_rate=0.1, beta1=2)
|
Loading…
Reference in New Issue