forked from mindspore-Ecosystem/mindspore
!3217 rename operator of sparse optimizer
Merge pull request !3217 from wangnan39/rename_the_operator_of_sparse_optimizer
This commit is contained in:
commit
5d42d00161
|
@ -40,7 +40,7 @@ class SparseApplyAdamCPUKernel : public CPUKernel {
|
|||
bool use_nesterov_{false};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(SparseApplyAdam,
|
||||
MS_REG_CPU_KERNEL(FusedSparseAdam,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
|
|
@ -42,19 +42,7 @@ class SparseApplyFtrlCPUKernel : public CPUKernel {
|
|||
float lr_power_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(SparseApplyFtrl,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
SparseApplyFtrlCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(SparseApplyFtrlNoReturn,
|
||||
MS_REG_CPU_KERNEL(FusedSparseFtrl,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
|
|
@ -40,7 +40,7 @@ class SparseApplyLazyAdamCPUKernel : public CPUKernel {
|
|||
bool use_nesterov_{false};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(SparseApplyLazyAdam,
|
||||
MS_REG_CPU_KERNEL(FusedSparseLazyAdam,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
|
|
@ -39,20 +39,7 @@ class SparseApplyProximalAdagradCPUKernel : public CPUKernel {
|
|||
size_t var_outer_dim_size_{1};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(SparseApplyProximalAdagrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
SparseApplyProximalAdagradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(SparseApplyProximalAdagradNoReturn,
|
||||
MS_REG_CPU_KERNEL(FusedSparseProximalAdagrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
|
|
@ -299,7 +299,7 @@ class Adam(Optimizer):
|
|||
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.Adam(use_locking, use_nesterov)
|
||||
self.sparse_opt = P.SparseApplyAdam(use_locking, use_nesterov)
|
||||
self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov)
|
||||
|
||||
def construct(self, gradients):
|
||||
params = self.parameters
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.common import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from .optimizer import Optimizer, _apply_decay, _grad_scale
|
||||
|
@ -159,7 +158,7 @@ class FTRL(Optimizer):
|
|||
self.decay_tf = tuple((lambda: True)() for x in self.parameters)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.ApplyFtrl(use_locking=use_locking)
|
||||
self.sparse_opt = inner.SparseApplyFtrlNoReturn(learning_rate, l1, l2, lr_power, use_locking=use_locking)
|
||||
self.sparse_opt = P.FusedSparseFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
|
||||
|
||||
def construct(self, grads):
|
||||
params = self.parameters
|
||||
|
|
|
@ -182,7 +182,7 @@ class LazyAdam(Optimizer):
|
|||
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.Adam(use_locking, use_nesterov)
|
||||
self.sparse_opt = P.SparseApplyLazyAdam(use_locking, use_nesterov)
|
||||
self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov)
|
||||
|
||||
def construct(self, gradients):
|
||||
gradients = self.decay_weight(gradients)
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.common import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from .optimizer import Optimizer
|
||||
|
@ -101,7 +100,7 @@ class ProximalAdagrad(Optimizer):
|
|||
self.weight_decay = weight_decay
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.ApplyProximalAdagrad(use_locking=use_locking)
|
||||
self.sparse_opt = inner.SparseApplyProximalAdagradNoReturn(use_locking=use_locking)
|
||||
self.sparse_opt = P.FusedSparseProximalAdagrad(use_locking=use_locking)
|
||||
|
||||
def construct(self, grads):
|
||||
params = self.parameters
|
||||
|
|
|
@ -56,7 +56,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
|
|||
|
||||
from .random_ops import (RandomChoiceWithMask, Normal, Gamma, Poisson, UniformInt, UniformReal,
|
||||
RandomCategorical, Laplace)
|
||||
from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm,
|
||||
from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, ApplyMomentum, BatchNorm,
|
||||
BiasAdd, Conv2D,
|
||||
DepthwiseConv2dNative,
|
||||
DropoutDoMask, DropoutGrad, Dropout,
|
||||
|
@ -74,6 +74,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
|
|||
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
||||
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
|
||||
ApplyProximalAdagrad, SparseApplyProximalAdagrad, SparseApplyAdagradV2, SparseApplyFtrlV2,
|
||||
FusedSparseFtrl, FusedSparseProximalAdagrad,
|
||||
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
|
||||
ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
|
||||
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
|
||||
|
@ -114,8 +115,8 @@ __all__ = [
|
|||
'MaxPool',
|
||||
'TopK',
|
||||
'Adam',
|
||||
'SparseApplyAdam',
|
||||
'SparseApplyLazyAdam',
|
||||
'FusedSparseAdam',
|
||||
'FusedSparseLazyAdam',
|
||||
'Softplus',
|
||||
'Softmax',
|
||||
'Softsign',
|
||||
|
@ -311,8 +312,10 @@ __all__ = [
|
|||
"SpaceToBatch",
|
||||
"SparseApplyFtrl",
|
||||
"SparseApplyFtrlV2",
|
||||
"FusedSparseFtrl",
|
||||
"ApplyProximalAdagrad",
|
||||
"SparseApplyProximalAdagrad",
|
||||
"FusedSparseProximalAdagrad",
|
||||
"ApplyAdaMax",
|
||||
"ApplyAdadelta",
|
||||
"ApplyAdagrad",
|
||||
|
|
|
@ -18,9 +18,6 @@
|
|||
from ..._checkparam import Rel
|
||||
from ..._checkparam import Validator as validator
|
||||
from ...common import dtype as mstype
|
||||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..._c_expression import signature_dtype as sig_dtype
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||
|
||||
|
||||
|
@ -394,183 +391,6 @@ class Dequant(PrimitiveWithInfer):
|
|||
return mstype.float16
|
||||
|
||||
|
||||
class SparseApplyFtrlNoReturn(PrimitiveWithInfer):
|
||||
"""
|
||||
Update relevant entries according to the FTRL-proximal scheme.
|
||||
|
||||
Args:
|
||||
lr (float): The learning rate value, must be positive.
|
||||
l1 (float): l1 regularization strength, must be greater than or equal to zero.
|
||||
l2 (float): l2 regularization strength, must be greater than or equal to zero.
|
||||
lr_power (float): Learning rate power controls how the learning rate decreases during training,
|
||||
must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero.
|
||||
use_locking (bool): Use locks for update operation if True . Default: False.
|
||||
|
||||
Inputs:
|
||||
- **var** (Parameter): The variable to be updated. The data type must be float32.
|
||||
- **accum** (Parameter): The accum to be updated, must be same type and shape as `var`.
|
||||
- **linear** (Parameter): The linear to be updated, must be same type and shape as `var`.
|
||||
- **grad** (Tensor): A tensor of the same type as `var`, for the gradient.
|
||||
- **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. The shape
|
||||
of `indices` must be the same as `grad` in first dimension. The type must be int32.
|
||||
|
||||
Outputs:
|
||||
Tuple of 3 Tensor, this operator will update the input parameters directly, the outputs are useless.
|
||||
|
||||
- **var** (Tensor) - A Tensor with shape (1,).
|
||||
- **accum** (Tensor) - A Tensor with shape (1,).
|
||||
- **linear** (Tensor) - A Tensor with shape (1,).
|
||||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> import mindspore.nn as nn
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Parameter
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.ops import operations as P
|
||||
>>> class SparseApplyFtrlNet(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(SparseApplyFtrlNet, self).__init__()
|
||||
>>> self.sparse_apply_ftrl = P.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
|
||||
>>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
|
||||
>>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
|
||||
>>> self.linear = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="linear")
|
||||
>>>
|
||||
>>> def construct(self, grad, indices):
|
||||
>>> out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
|
||||
>>> return out
|
||||
>>>
|
||||
>>> net = SparseApplyFtrlNet()
|
||||
>>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
|
||||
>>> indices = Tensor(np.array([0, 1]).astype(np.int32))
|
||||
>>> output = net(grad, indices)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, lr, l1, l2, lr_power, use_locking=False):
|
||||
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
|
||||
outputs=['output'])
|
||||
validator.check_value_type("lr", lr, [float], self.name)
|
||||
validator.check_value_type("l1", l1, [float], self.name)
|
||||
validator.check_value_type("l2", l2, [float], self.name)
|
||||
validator.check_value_type("lr_power", lr_power, [float], self.name)
|
||||
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
|
||||
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
||||
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
||||
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
|
||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
self.add_prim_attr('primitive_target', 'CPU')
|
||||
|
||||
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
|
||||
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
|
||||
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
|
||||
if len(var_shape) > 1:
|
||||
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
|
||||
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
|
||||
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
|
||||
return [1], [1], [1]
|
||||
|
||||
def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
|
||||
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
|
||||
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
|
||||
validator.check_tensor_type_same(args, [mstype.float32], self.name)
|
||||
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
|
||||
return var_dtype, accum_dtype, linear_dtype
|
||||
|
||||
|
||||
class SparseApplyProximalAdagradNoReturn(PrimitiveWithInfer):
|
||||
r"""
|
||||
Updates relevant entries according to the proximal adagrad algorithm.
|
||||
|
||||
.. math::
|
||||
accum += grad * grad
|
||||
.. math::
|
||||
\text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}
|
||||
.. math::
|
||||
var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
|
||||
|
||||
Args:
|
||||
use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
|
||||
- **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
|
||||
- **lr** (Tensor): The learning rate value. The data type must be float32.
|
||||
- **l1** (Tensor): l1 regularization strength. The data type must be float32.
|
||||
- **l2** (Tensor): l2 regularization strength. The data type must be float32.
|
||||
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32.
|
||||
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The data type
|
||||
must be int32.
|
||||
|
||||
Outputs:
|
||||
Tuple of 2 Tensor, this operator will update the input parameters directly, the outputs are useless.
|
||||
|
||||
- **var** (Tensor) - A Tensor with shape (1,).
|
||||
- **accum** (Tensor) - A Tensor with shape (1,).
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore import Tensor, Parameter
|
||||
>>> from mindspore.ops import operations as P
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
>>> self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagradV2()
|
||||
>>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
|
||||
>>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
|
||||
>>> self.lr = Tensor(0.01, mstype.float32)
|
||||
>>> self.l1 = Tensor(0.0, mstype.float32)
|
||||
>>> self.l2 = Tensor(0.0, mstype.float32)
|
||||
>>> def construct(self, grad, indices):
|
||||
>>> out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1,
|
||||
>>> self.l2, grad, indices)
|
||||
>>> return out
|
||||
>>> net = Net()
|
||||
>>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
|
||||
>>> indices = Tensor(np.array([0, 1]).astype(np.int32))
|
||||
>>> output = net(grad, indices)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=False):
|
||||
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
|
||||
outputs=['output'])
|
||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
self.add_prim_attr('primitive_target', 'CPU')
|
||||
|
||||
def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
|
||||
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
|
||||
return [1], [1]
|
||||
|
||||
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
|
||||
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
|
||||
validator.check_tensor_type_same(args, [mstype.float32], self.name)
|
||||
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float32], self.name)
|
||||
validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float32], self.name)
|
||||
validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float32], self.name)
|
||||
valid_types = [mstype.int16, mstype.int32, mstype.int64,
|
||||
mstype.uint16, mstype.uint32, mstype.uint64]
|
||||
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
|
||||
return var_dtype, accum_dtype
|
||||
|
||||
|
||||
class LinSpace(PrimitiveWithInfer):
|
||||
r"""
|
||||
Generates values in an interval. And return the corresponding interpolation accroding to assist.
|
||||
|
|
|
@ -2917,7 +2917,7 @@ class Adam(PrimitiveWithInfer):
|
|||
return var_dtype, m_dtype, v_dtype
|
||||
|
||||
|
||||
class SparseApplyAdam(PrimitiveWithInfer):
|
||||
class FusedSparseAdam(PrimitiveWithInfer):
|
||||
r"""
|
||||
Merge the duplicate value of the gradient and then updates parameters by Adaptive Moment Estimation (Adam)
|
||||
algorithm. This operator is used when the gradient is sparse.
|
||||
|
@ -2979,7 +2979,7 @@ class SparseApplyAdam(PrimitiveWithInfer):
|
|||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
>>> self.sparse_apply_adam = P.SparseApplyAdam()
|
||||
>>> self.sparse_apply_adam = P.FusedSparseAdam()
|
||||
>>> self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
|
||||
>>> self.m = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="m")
|
||||
>>> self.v = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="v")
|
||||
|
@ -3025,7 +3025,6 @@ class SparseApplyAdam(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2',
|
||||
'epsilon', 'grad', 'indices'],
|
||||
outputs=['var', 'm', 'v'])
|
||||
self.add_prim_attr('primitive_target', 'CPU')
|
||||
|
||||
def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
|
||||
beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape):
|
||||
|
@ -3051,7 +3050,7 @@ class SparseApplyAdam(PrimitiveWithInfer):
|
|||
return var_dtype, m_dtype, v_dtype
|
||||
|
||||
|
||||
class SparseApplyLazyAdam(PrimitiveWithInfer):
|
||||
class FusedSparseLazyAdam(PrimitiveWithInfer):
|
||||
r"""
|
||||
Merge the duplicate value of the gradient and then updates parameters by Adaptive Moment Estimation (Adam)
|
||||
algorithm. This operator is used when the gradient is sparse. The behavior is not equivalent to the
|
||||
|
@ -3114,7 +3113,7 @@ class SparseApplyLazyAdam(PrimitiveWithInfer):
|
|||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
>>> self.sparse_apply_lazyadam = P.SparseApplyLazyAdam()
|
||||
>>> self.sparse_apply_lazyadam = P.FusedSparseLazyAdam()
|
||||
>>> self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
|
||||
>>> self.m = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="m")
|
||||
>>> self.v = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="v")
|
||||
|
@ -3160,7 +3159,6 @@ class SparseApplyLazyAdam(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2',
|
||||
'epsilon', 'grad', 'indices'],
|
||||
outputs=['var', 'm', 'v'])
|
||||
self.add_prim_attr('primitive_target', 'CPU')
|
||||
|
||||
def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
|
||||
beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape):
|
||||
|
@ -3187,6 +3185,182 @@ class SparseApplyLazyAdam(PrimitiveWithInfer):
|
|||
return var_dtype, m_dtype, v_dtype
|
||||
|
||||
|
||||
class FusedSparseFtrl(PrimitiveWithInfer):
|
||||
"""
|
||||
Merge the duplicate value of the gradient and then update relevant entries according to the FTRL-proximal scheme.
|
||||
|
||||
Args:
|
||||
lr (float): The learning rate value, must be positive.
|
||||
l1 (float): l1 regularization strength, must be greater than or equal to zero.
|
||||
l2 (float): l2 regularization strength, must be greater than or equal to zero.
|
||||
lr_power (float): Learning rate power controls how the learning rate decreases during training,
|
||||
must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero.
|
||||
use_locking (bool): Use locks for update operation if True . Default: False.
|
||||
|
||||
Inputs:
|
||||
- **var** (Parameter): The variable to be updated. The data type must be float32.
|
||||
- **accum** (Parameter): The accum to be updated, must be same type and shape as `var`.
|
||||
- **linear** (Parameter): The linear to be updated, must be same type and shape as `var`.
|
||||
- **grad** (Tensor): A tensor of the same type as `var`, for the gradient.
|
||||
- **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. The shape
|
||||
of `indices` must be the same as `grad` in first dimension. The type must be int32.
|
||||
|
||||
Outputs:
|
||||
Tuple of 3 Tensor, this operator will update the input parameters directly, the outputs are useless.
|
||||
|
||||
- **var** (Tensor) - A Tensor with shape (1,).
|
||||
- **accum** (Tensor) - A Tensor with shape (1,).
|
||||
- **linear** (Tensor) - A Tensor with shape (1,).
|
||||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> import mindspore.nn as nn
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Parameter
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.ops import operations as P
|
||||
>>> class SparseApplyFtrlNet(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(SparseApplyFtrlNet, self).__init__()
|
||||
>>> self.sparse_apply_ftrl = P.FusedSparseFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
|
||||
>>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
|
||||
>>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
|
||||
>>> self.linear = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="linear")
|
||||
>>>
|
||||
>>> def construct(self, grad, indices):
|
||||
>>> out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
|
||||
>>> return out
|
||||
>>>
|
||||
>>> net = SparseApplyFtrlNet()
|
||||
>>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
|
||||
>>> indices = Tensor(np.array([0, 1]).astype(np.int32))
|
||||
>>> output = net(grad, indices)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, lr, l1, l2, lr_power, use_locking=False):
|
||||
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
|
||||
outputs=['output'])
|
||||
validator.check_value_type("lr", lr, [float], self.name)
|
||||
validator.check_value_type("l1", l1, [float], self.name)
|
||||
validator.check_value_type("l2", l2, [float], self.name)
|
||||
validator.check_value_type("lr_power", lr_power, [float], self.name)
|
||||
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
|
||||
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
||||
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
||||
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
|
||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
|
||||
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
|
||||
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
|
||||
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
|
||||
if len(var_shape) > 1:
|
||||
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
|
||||
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
|
||||
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
|
||||
return [1], [1], [1]
|
||||
|
||||
def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
|
||||
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
|
||||
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
|
||||
validator.check_tensor_type_same(args, [mstype.float32], self.name)
|
||||
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
|
||||
return var_dtype, accum_dtype, linear_dtype
|
||||
|
||||
|
||||
class FusedSparseProximalAdagrad(PrimitiveWithInfer):
|
||||
r"""
|
||||
Merge the duplicate value of the gradient and then Updates relevant entries according to the proximal adagrad
|
||||
algorithm.
|
||||
|
||||
.. math::
|
||||
accum += grad * grad
|
||||
.. math::
|
||||
\text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}
|
||||
.. math::
|
||||
var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
|
||||
|
||||
Args:
|
||||
use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
|
||||
- **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
|
||||
- **lr** (Tensor): The learning rate value. The data type must be float32.
|
||||
- **l1** (Tensor): l1 regularization strength. The data type must be float32.
|
||||
- **l2** (Tensor): l2 regularization strength. The data type must be float32.
|
||||
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32.
|
||||
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The data type
|
||||
must be int32.
|
||||
|
||||
Outputs:
|
||||
Tuple of 2 Tensor, this operator will update the input parameters directly, the outputs are useless.
|
||||
|
||||
- **var** (Tensor) - A Tensor with shape (1,).
|
||||
- **accum** (Tensor) - A Tensor with shape (1,).
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore import Tensor, Parameter
|
||||
>>> from mindspore.ops import operations as P
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
>>> self.sparse_apply_proximal_adagrad = P.FusedSparseProximalAdagrad()
|
||||
>>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
|
||||
>>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
|
||||
>>> self.lr = Tensor(0.01, mstype.float32)
|
||||
>>> self.l1 = Tensor(0.0, mstype.float32)
|
||||
>>> self.l2 = Tensor(0.0, mstype.float32)
|
||||
>>> def construct(self, grad, indices):
|
||||
>>> out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1,
|
||||
>>> self.l2, grad, indices)
|
||||
>>> return out
|
||||
>>> net = Net()
|
||||
>>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
|
||||
>>> indices = Tensor(np.array([0, 1]).astype(np.int32))
|
||||
>>> output = net(grad, indices)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=False):
|
||||
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
|
||||
outputs=['output'])
|
||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
|
||||
def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
|
||||
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
|
||||
return [1], [1]
|
||||
|
||||
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
|
||||
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
|
||||
validator.check_tensor_type_same(args, [mstype.float32], self.name)
|
||||
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float32], self.name)
|
||||
validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float32], self.name)
|
||||
validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float32], self.name)
|
||||
valid_types = [mstype.int16, mstype.int32, mstype.int64,
|
||||
mstype.uint16, mstype.uint32, mstype.uint64]
|
||||
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
|
||||
return var_dtype, accum_dtype
|
||||
|
||||
|
||||
class BinaryCrossEntropy(PrimitiveWithInfer):
|
||||
r"""
|
||||
Computes the Binary Cross Entropy between the target and the output.
|
||||
|
|
|
@ -33,7 +33,7 @@ epsilon = 1e-8
|
|||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.sparse_apply_adam = P.SparseApplyAdam()
|
||||
self.sparse_apply_adam = P.FusedSparseAdam()
|
||||
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
|
||||
self.m = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="m")
|
||||
self.v = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="v")
|
||||
|
|
|
@ -26,7 +26,7 @@ import mindspore.common.dtype as mstype
|
|||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5)
|
||||
self.sparse_apply_ftrl = P.FusedSparseFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5)
|
||||
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
|
||||
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
|
||||
self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="linear")
|
||||
|
|
|
@ -26,7 +26,7 @@ import mindspore.common.dtype as mstype
|
|||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad()
|
||||
self.sparse_apply_proximal_adagrad = P.FusedSparseProximalAdagrad()
|
||||
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
|
||||
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
|
||||
self.lr = 0.01
|
||||
|
|
Loading…
Reference in New Issue