forked from mindspore-Ecosystem/mindspore

add ops SparseApplyAdam and SparseApplyLazyAdam

This commit is contained in:
parent 72fd41786c
commit de21dbdaef
@@ -52,7 +52,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2
                       Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
                       Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh)
from .random_ops import (RandomChoiceWithMask)
from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm,
                     BiasAdd, Conv2D,
                     DepthwiseConv2dNative,
                     DropoutDoMask, DropoutGrad, Dropout,
@@ -101,6 +101,8 @@ __all__ = [
    'MaxPool',
    'TopK',
    'Adam',
    'SparseApplyAdam',
    'SparseApplyLazyAdam',
    'Softplus',
    'Softmax',
    'LogSoftmax',
@@ -2646,9 +2646,25 @@ class Adam(PrimitiveWithInfer):
        - **v** (Tensor) - The same shape and data type as `v`.

    Examples:
        Please refer to the usage in nn.Adam.
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, Parameter
        >>> from mindspore.ops import operations as P
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.apply_adam = P.Adam()
        >>>         self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
        >>>         self.m = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="m")
        >>>         self.v = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="v")
        >>>     def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
        >>>         out = self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2,
        >>>                               epsilon, grad)
        >>>         return out
        >>> net = Net()
        >>> gradient = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
        >>> result = net(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient)
    """

    @prim_attr_register
    def __init__(self, use_locking=False, use_nesterov=False):
        validator.check_value_type("use_locking", use_locking, [bool], self.name)

@@ -2672,6 +2688,260 @@ class Adam(PrimitiveWithInfer):
        return var_dtype, m_dtype, v_dtype

class SparseApplyAdam(PrimitiveWithInfer):
    r"""
    Merges duplicate values in the gradient and then updates parameters by the Adaptive Moment Estimation (Adam)
    algorithm. This operator is used when the gradient is sparse.

    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m = \beta_1 * m + (1 - \beta_1) * g \\
            v = \beta_2 * v + (1 - \beta_2) * g * g \\
            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
            w = w - l * \frac{m}{\sqrt{v} + \epsilon}
        \end{array}

    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
    `gradient`, :math:`l` represents the scaling factor `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
    :math:`t` represents the updating step while :math:`\beta_1^t` and :math:`\beta_2^t` represent `beta1_power` and
    `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `var`, and :math:`\epsilon`
    represents `epsilon`.

    Args:
        use_locking (bool): Whether to enable a lock to protect the updating of variable tensors.
            If True, updating of the var, m, and v tensors will be protected by a lock.
            If False, the result is unpredictable. Default: False.
        use_nesterov (bool): Whether to use the Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
            If True, updates the gradients using NAG.
            If False, updates the gradients without using NAG. Default: False.

    Inputs:
        - **var** (Parameter) - Parameters to be updated.
        - **m** (Parameter) - The 1st moment vector in the updating formula. Has the same type as `var`.
        - **v** (Parameter) - The 2nd moment vector in the updating formula, i.e. the mean square of gradients.
          Has the same type as `var`.
        - **beta1_power** (float) - :math:`\beta_1^t` in the updating formula.
        - **beta2_power** (float) - :math:`\beta_2^t` in the updating formula.
        - **lr** (float) - :math:`l` in the updating formula.
        - **beta1** (float) - The exponential decay rate for the 1st moment estimates.
        - **beta2** (float) - The exponential decay rate for the 2nd moment estimates.
        - **epsilon** (float) - Term added to the denominator to improve numerical stability.
        - **gradient** (Tensor) - Gradient value.
        - **indices** (Tensor) - Gradient indices, with int32 data type.

    Outputs:
        Tuple of 3 Tensors, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **m** (Tensor) - The same shape and data type as `m`.
        - **v** (Tensor) - The same shape and data type as `v`.

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, Parameter
        >>> from mindspore.ops import operations as P
        >>> import mindspore.common.dtype as mstype
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.sparse_apply_adam = P.SparseApplyAdam()
        >>>         self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
        >>>         self.m = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="m")
        >>>         self.v = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="v")
        >>>     def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices):
        >>>         out = self.sparse_apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2,
        >>>                                      epsilon, grad, indices)
        >>>         return out
        >>> net = Net()
        >>> gradient = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
        >>> indices = Tensor([0, 1, 2], mstype.int32)
        >>> result = net(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient, indices)
    """
    __mindspore_signature__ = (
        ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('v', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('beta1_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
         sig_dtype.T),
        ('beta2_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
         sig_dtype.T),
        ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
         sig_dtype.T),
        ('beta1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
         sig_dtype.T),
        ('beta2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
         sig_dtype.T),
        ('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
         sig_dtype.T),
        ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
    )

    @prim_attr_register
    def __init__(self, use_locking=False, use_nesterov=False):
        validator.check_value_type("use_locking", use_locking, [bool], self.name)
        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
        self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2',
                                        'epsilon', 'grad', 'indices'],
                                outputs=['var', 'm', 'v'])

    def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
                    beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape):
        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
        validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
        validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
        if len(var_shape) > 1 and grad_shape != indices_shape + var_shape[1:]:
            raise ValueError(f"For '{self.name}', grad_shape should be equal to indices_shape + var_shape[1:], "
                             f"but got var_shape: {var_shape}, indices_shape: {indices_shape}, "
                             f"grad_shape: {grad_shape}.")
        return var_shape, m_shape, v_shape

    def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
                    beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype, indices_dtype):
        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)

        args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
                "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
        validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True)
        validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
        return var_dtype, m_dtype, v_dtype

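The documented update is easiest to follow as a small NumPy sketch. This is an illustrative reading only, not the actual kernel: it assumes the non-lazy variant merges duplicate indices, densifies the sparse gradient, and then applies the ordinary Adam step to the whole variable; `sparse_adam_sketch` is a hypothetical helper and `use_locking`/`use_nesterov` are ignored.

# Illustrative NumPy sketch of the documented SparseApplyAdam formulas (not the kernel).
import numpy as np

def sparse_adam_sketch(var, m, v, beta1_power, beta2_power, lr,
                       beta1, beta2, epsilon, grad, indices):
    # Shape contract enforced by infer_shape above:
    # grad.shape == indices.shape + var.shape[1:]
    assert grad.shape == indices.shape + var.shape[1:]

    # Merge duplicate indices and densify the sparse gradient.
    dense_grad = np.zeros_like(var)
    np.add.at(dense_grad, indices, grad)

    # Standard (dense) Adam step, matching the docstring formulas.
    m[:] = beta1 * m + (1 - beta1) * dense_grad
    v[:] = beta2 * v + (1 - beta2) * dense_grad * dense_grad
    scaled_lr = lr * np.sqrt(1 - beta2_power) / (1 - beta1_power)
    var[:] -= scaled_lr * m / (np.sqrt(v) + epsilon)
    return var, m, v

The assert mirrors the check performed in infer_shape; every index selects one slice of `var`, and that slice shape must match the trailing dimensions of `grad`.
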
class SparseApplyLazyAdam(PrimitiveWithInfer):
    r"""
    Merges duplicate values in the gradient and then updates parameters by the Adaptive Moment Estimation (Adam)
    algorithm. This operator is used when the gradient is sparse. The behavior is not equivalent to the
    original Adam algorithm, as only the parameters at the current indices are updated.

    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m = \beta_1 * m + (1 - \beta_1) * g \\
            v = \beta_2 * v + (1 - \beta_2) * g * g \\
            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
            w = w - l * \frac{m}{\sqrt{v} + \epsilon}
        \end{array}

    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
    `gradient`, :math:`l` represents the scaling factor `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
    :math:`t` represents the updating step while :math:`\beta_1^t` and :math:`\beta_2^t` represent `beta1_power` and
    `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `var`, and :math:`\epsilon`
    represents `epsilon`.

    Args:
        use_locking (bool): Whether to enable a lock to protect the updating of variable tensors.
            If True, updating of the var, m, and v tensors will be protected by a lock.
            If False, the result is unpredictable. Default: False.
        use_nesterov (bool): Whether to use the Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
            If True, updates the gradients using NAG.
            If False, updates the gradients without using NAG. Default: False.

    Inputs:
        - **var** (Parameter) - Weights to be updated.
        - **m** (Parameter) - The 1st moment vector in the updating formula. Has the same type as `var`.
        - **v** (Parameter) - The 2nd moment vector in the updating formula, i.e. the mean square of gradients.
          Has the same type as `var`.
        - **beta1_power** (float) - :math:`\beta_1^t` in the updating formula.
        - **beta2_power** (float) - :math:`\beta_2^t` in the updating formula.
        - **lr** (float) - :math:`l` in the updating formula.
        - **beta1** (float) - The exponential decay rate for the 1st moment estimates.
        - **beta2** (float) - The exponential decay rate for the 2nd moment estimates.
        - **epsilon** (float) - Term added to the denominator to improve numerical stability.
        - **gradient** (Tensor) - Gradient value.
        - **indices** (Tensor) - Gradient indices, with int32 data type.

    Outputs:
        Tuple of 3 Tensors, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **m** (Tensor) - The same shape and data type as `m`.
        - **v** (Tensor) - The same shape and data type as `v`.

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, Parameter
        >>> from mindspore.ops import operations as P
        >>> import mindspore.common.dtype as mstype
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.sparse_apply_lazyadam = P.SparseApplyLazyAdam()
        >>>         self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
        >>>         self.m = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="m")
        >>>         self.v = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="v")
        >>>     def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices):
        >>>         out = self.sparse_apply_lazyadam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1,
        >>>                                          beta2, epsilon, grad, indices)
        >>>         return out
        >>> net = Net()
        >>> gradient = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
        >>> indices = Tensor([0, 1, 2], mstype.int32)
        >>> result = net(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient, indices)
    """
    __mindspore_signature__ = (
        ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('v', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('beta1_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
         sig_dtype.T),
        ('beta2_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
         sig_dtype.T),
        ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
         sig_dtype.T),
        ('beta1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
         sig_dtype.T),
        ('beta2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
         sig_dtype.T),
        ('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
         sig_dtype.T),
        ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
    )

    @prim_attr_register
    def __init__(self, use_locking=False, use_nesterov=False):
        validator.check_value_type("use_locking", use_locking, [bool], self.name)
        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
        self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2',
                                        'epsilon', 'grad', 'indices'],
                                outputs=['var', 'm', 'v'])

    def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
                    beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape):
        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
        validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
        validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
        if len(var_shape) > 1 and grad_shape != indices_shape + var_shape[1:]:
            raise ValueError(f"For '{self.name}', grad_shape should be equal to indices_shape + var_shape[1:], "
                             f"but got var_shape: {var_shape}, indices_shape: {indices_shape}, "
                             f"grad_shape: {grad_shape}.")
        return var_shape, m_shape, v_shape

    def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
                    beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype, indices_dtype):
        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)

        args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
                "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
        validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True)

        validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
        return var_dtype, m_dtype, v_dtype

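For contrast, here is a minimal sketch of the lazy variant under the assumption stated in its docstring: only the rows selected by `indices` are touched. Again this is illustrative only; `sparse_lazy_adam_sketch` is a hypothetical helper, duplicate indices are assumed to have been merged beforehand, and `use_locking`/`use_nesterov` are ignored.

# Illustrative NumPy sketch of the lazy update: only rows at `indices` change.
import numpy as np

def sparse_lazy_adam_sketch(var, m, v, beta1_power, beta2_power, lr,
                            beta1, beta2, epsilon, grad, indices):
    scaled_lr = lr * np.sqrt(1 - beta2_power) / (1 - beta1_power)
    for i, idx in enumerate(indices):
        m[idx] = beta1 * m[idx] + (1 - beta1) * grad[i]
        v[idx] = beta2 * v[idx] + (1 - beta2) * grad[i] * grad[i]
        var[idx] -= scaled_lr * m[idx] / (np.sqrt(v[idx]) + epsilon)
    return var, m, v

Rows absent from `indices` keep stale m and v, which is why the docstring flags the behavior as not equivalent to the original Adam algorithm.
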
class BinaryCrossEntropy(PrimitiveWithInfer):
    r"""
    Computes the Binary Cross Entropy between the target and the output.

@@ -18,6 +18,7 @@ import pytest

import mindspore.nn as nn
from mindspore import Tensor, Parameter
import mindspore.common.dtype as mstype
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import AdamWeightDecay, AdamWeightDecayDynamicLR
@@ -108,3 +109,34 @@ def test_adam_mindspore_flatten():
    net = nn.Flatten()
    with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"):
        AdamWeightDecay(net.get_parameters())


class TestSparseOps(nn.Cell):
    """Define sparse operator"""
    def __init__(self, sparse_opt):
        super(TestSparseOps, self).__init__()
        self.sparse_apply_adam = sparse_opt
        self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
        self.m = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="m")
        self.v = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="v")

    def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices):
        out = self.sparse_apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon,
                                     grad, indices)
        return out


def test_sparse_adam():
    """test sparse operator"""
    gradient = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
    indices = Tensor([0, 1, 2], mstype.int32)
    net = TestSparseOps(P.SparseApplyAdam())
    _executor.compile(net, 0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient, indices)


def test_sparse_lazy_adam():
    """test sparse operator"""
    gradient = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
    indices = Tensor([0, 1, 2], mstype.int32)
    net = TestSparseOps(P.SparseApplyLazyAdam())
    _executor.compile(net, 0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient, indices)
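
As a worked instance of the shape contract that both infer_shape methods enforce, using hypothetical values in plain NumPy:

import numpy as np

# With var of shape (3, 3, 3) and indices of shape (2,), grad must have shape
# (2,) + (3, 3) == (2, 3, 3): one gradient slice per index.
var = np.ones((3, 3, 3), np.float32)
indices = np.array([0, 2], np.int32)
grad = np.random.rand(2, 3, 3).astype(np.float32)
assert grad.shape == indices.shape + var.shape[1:]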