forked from mindspore-Ecosystem/mindspore
!2069 Fix input value check for Momentum, SparseApplyFtrl and SparseApplyAdagrad.
Merge pull request !2069 from liuxiao/fix-input-value-check-for-SparseApplyFtrl-and-SparseApplyAdaGrad
This commit is contained in:
commit
cef901fe0d
|
@ -74,6 +74,7 @@ static std::map<string, string> tbe_func_adapter_map = {
|
|||
{"apply_adadelta", "apply_adadelta_d"},
|
||||
{"apply_adagrad", "apply_adagrad_d"},
|
||||
{"apply_adagrad_v2", "apply_adagradv2_d"},
|
||||
{"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
|
||||
{"transpose", "transpose_d"},
|
||||
{"fill", "fill_d"},
|
||||
{"unsorted_segment_sum", "unsorted_segment_sum_d"},
|
||||
|
|
|
@ -18,6 +18,7 @@ from mindspore.common.parameter import Parameter
|
|||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import check_bool
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from .optimizer import Optimizer
|
||||
|
||||
momentum_opt = C.MultitypeFuncGraph("momentum_opt")
|
||||
|
@ -65,16 +66,18 @@ class Momentum(Optimizer):
|
|||
in the value of 'order_params' but not in any group will use default learning rate and default weight
|
||||
decay.
|
||||
|
||||
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
|
||||
Iterable or a Tensor and the dims of the Tensor is 1,
|
||||
use dynamic learning rate, then the i-th step will
|
||||
take the i-th value as the learning rate.
|
||||
When the learning_rate is float or learning_rate is a Tensor
|
||||
but the dims of the Tensor is 0, use fixed learning rate.
|
||||
Other cases are not supported.
|
||||
learning_rate (Union[int, float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
|
||||
Iterable or a Tensor and the dims of the Tensor is 1,
|
||||
use dynamic learning rate, then the i-th step will
|
||||
take the i-th value as the learning rate.
|
||||
When the learning_rate is float or learning_rate is a
|
||||
Tensor but the dims of the Tensor is 0, use fixed learning
|
||||
rate. Other cases are not supported. It should be equal to
|
||||
or greater than 0.0.
|
||||
momentum (float): Hyperparameter of type float, means momentum for the moving average.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
|
||||
loss_scale (float): A floating point value for the loss scale. Default: 1.0.
|
||||
It should be at least 0.0.
|
||||
weight_decay (int, float): Weight decay (L2 penalty). It should be equal to or greater than 0.0. Default: 0.0.
|
||||
loss_scale (int, float): A floating point value for the loss scale. It should be greater than 0.0. Default: 1.0.
|
||||
use_nesterov (bool): Enable Nesterov momentum. Default: False.
|
||||
|
||||
Inputs:
|
||||
|
@ -109,6 +112,7 @@ class Momentum(Optimizer):
|
|||
"""
|
||||
def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
|
||||
super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
validator.check_value_type("momentum", momentum, [float], self.cls_name)
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
|
||||
|
|
|
@ -13,15 +13,15 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""SparseApplyAdagrad op"""
|
||||
"""SparseApplyAdagradD op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
sparse_apply_adagrad_op_info = TBERegOp("SparseApplyAdagrad") \
|
||||
sparse_apply_adagrad_d_op_info = TBERegOp("SparseApplyAdagrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("sparse_apply_adagrad.so") \
|
||||
.binfile_name("sparse_apply_adagrad_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("sparse_apply_adagrad") \
|
||||
.kernel_name("sparse_apply_adagrad_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("lr", "required", "float", "all") \
|
||||
.attr("update_slots", "optional", "bool", "all") \
|
||||
|
@ -31,14 +31,17 @@ sparse_apply_adagrad_op_info = TBERegOp("SparseApplyAdagrad") \
|
|||
.input(2, "grad", False, "required", "all") \
|
||||
.input(3, "indices", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.F32_NHWC) \
|
||||
.output(1, "accum", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW,
|
||||
DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC,
|
||||
DataType.F32_NHWC, DataType.F32_NHWC) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,
|
||||
DataType.F32_Default) \
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(sparse_apply_adagrad_op_info)
|
||||
@op_info_register(sparse_apply_adagrad_d_op_info)
|
||||
def _sparse_apply_adagrad_tbe():
|
||||
"""SparseApplyAdagrad TBE register"""
|
||||
"""SparseApplyAdagradD TBE register"""
|
||||
return
|
||||
|
|
|
@ -3184,9 +3184,9 @@ class ApplyAdadelta(PrimitiveWithInfer):
|
|||
.. math::
|
||||
accum = \rho * accum + (1 - \rho) * grad^2
|
||||
.. math::
|
||||
update = \sqrt{accum_update + \esilon} * \rsqrt{accum + \epsilon} * grad
|
||||
\text{update} = \sqrt{\text{accum_update} + \epsilon} * \frac{grad}{\sqrt{accum + \epsilon}}
|
||||
.. math::
|
||||
accum_update = \rho * accum_update + (1 - \rho) * update^2
|
||||
\text{accum_update} = \rho * \text{accum_update} + (1 - \rho) * update^2
|
||||
.. math::
|
||||
var -= lr * update
|
||||
|
||||
|
@ -3377,11 +3377,12 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
|
|||
|
||||
Args:
|
||||
lr (float): Learning rate.
|
||||
update_slots (bool): If `True`, `accum` will be updated. Default: True.
|
||||
use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **var** (Tensor) - Variable to be updated. The type must be float32.
|
||||
- **accum** (Tensor) - Accum to be updated. The shape must be the same as `var`'s shape,
|
||||
- **var** (Parameter) - Variable to be updated. The type must be float32.
|
||||
- **accum** (Parameter) - Accum to be updated. The shape must be the same as `var`'s shape,
|
||||
the type must be float32.
|
||||
- **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape
|
||||
except first dimension, the type must be float32.
|
||||
|
@ -3389,21 +3390,45 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
|
|||
The shape of `indices` must be the same as `grad` in first dimension, the type must be int32.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and type as `var`.
|
||||
Tuple of 2 Tensor, the updated parameters.
|
||||
|
||||
- **var** (Tensor) - The same shape and data type as `var`.
|
||||
- **accum** (Tensor) - The same shape and data type as `accum`.
|
||||
|
||||
Examples:
|
||||
>>> var = Tensor(np.random.random((3, 3)), mindspore.float32)
|
||||
>>> accum = Tensor(np.random.random((3, 3)), mindspore.float32)
|
||||
>>> grad = Tensor(np.random.random((3, 3)), mindspore.float32)
|
||||
>>> indices = Tensor(np.ones((3,), np.int32))
|
||||
>>> sparse_apply_ada_grad = P.SparseApplyAdagrad(0.5)
|
||||
>>> sparse_apply_ada_grad(var, accum, grad, indices)
|
||||
>>> import numpy as np
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore import Tensor, Parameter
|
||||
>>> from mindspore.ops import operations as P
|
||||
>>> import mindspore.common.dtype as mstype
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
>>> self.sparse_apply_adagrad = P.SparseApplyAdagrad(lr=1e-8)
|
||||
>>> self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
|
||||
>>> self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
|
||||
>>> def construct(self, grad, indices):
|
||||
>>> out = self.sparse_apply_adagrad(self.var, self.accum, grad, indices)
|
||||
>>> return out
|
||||
>>> net = Net()
|
||||
>>> grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
|
||||
>>> indices = Tensor([0, 1, 2], mstype.int32)
|
||||
>>> result = net(grad, indices)
|
||||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, lr, use_locking=False):
|
||||
self.lr = validator.check_value_type("lr", lr, [float], self.name)
|
||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
def __init__(self, lr, update_slots=True, use_locking=False):
|
||||
validator.check_value_type("lr", lr, [float], self.name)
|
||||
validator.check_number_range("lr", lr, float("-inf"), float("inf"), Rel.INC_NEITHER, self.name)
|
||||
validator.check_value_type("update_slots", update_slots, [bool], self.name)
|
||||
validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
|
||||
def infer_shape(self, var_shape, accum_shape, grad_shape, indices_shape):
|
||||
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
|
||||
|
@ -3757,8 +3782,8 @@ class SparseApplyFtrl(PrimitiveWithInfer):
|
|||
validator.check_value_type("l2", l2, [float], self.name)
|
||||
validator.check_value_type("lr_power", lr_power, [float], self.name)
|
||||
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
|
||||
self.l1 = validator.check_number("l1", l1, 0.0, Rel.GE, self.name)
|
||||
self.l2 = validator.check_number("l2", l2, 0.0, Rel.GE, self.name)
|
||||
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
||||
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
||||
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
|
||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
|
||||
|
|
|
@ -82,7 +82,7 @@ def test_lenet_nccl():
|
|||
net.set_train()
|
||||
|
||||
learning_rate = multisteplr(epoch, 2)
|
||||
momentum = Tensor(np.array([0.9]).astype(np.float32))
|
||||
momentum = 0.9
|
||||
mom_optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
|
||||
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||
net_with_criterion = WithLossCell(net, criterion)
|
||||
|
|
|
@ -25,7 +25,6 @@ import mindspore.dataset.transforms.vision.c_transforms as CV
|
|||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.dataset.transforms.vision import Inter
|
||||
from mindspore.model_zoo.lenet import LeNet5
|
||||
from mindspore.nn import Dense, TrainOneStepCell, WithLossCell
|
||||
|
@ -84,7 +83,7 @@ def multisteplr(total_steps, gap, base_lr=0.9, gamma=0.1, dtype=mstype.float32):
|
|||
def test_train_lenet():
|
||||
epoch = 100
|
||||
net = LeNet()
|
||||
momentum = initializer(Tensor(np.array([0.9]).astype(np.float32)), [1])
|
||||
momentum = 0.9
|
||||
learning_rate = multisteplr(epoch, 30)
|
||||
|
||||
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
|
||||
|
|
|
@ -49,7 +49,7 @@ def test_momentum():
|
|||
epoch = 3
|
||||
net = NetMomentum()
|
||||
learning_rate = initializer(Tensor(np.array([0.01]).astype(np.float32)), [1])
|
||||
momentum = initializer(Tensor(np.array([0.9]).astype(np.float32)), [1])
|
||||
momentum = 0.9
|
||||
|
||||
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
|
||||
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||
|
|
|
@ -351,6 +351,17 @@ class ApplyAdagradV2Net(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
class SparseApplyAdagradNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SparseApplyAdagradNet, self).__init__()
|
||||
self.sparse_apply_adagrad = P.SparseApplyAdagrad(lr=0.01)
|
||||
self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
|
||||
self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
|
||||
|
||||
def construct(self, grad, indices):
|
||||
out = self.sparse_apply_adagrad(self.var, self.accum, grad, indices)
|
||||
return out
|
||||
|
||||
class ApplyRMSNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ApplyRMSNet, self).__init__()
|
||||
|
@ -1181,8 +1192,8 @@ test_case_nn_ops = [
|
|||
'desc_inputs': [[1, 2, 3], [1, 2, 3], [1, 2, 3]],
|
||||
'desc_bprop': []}),
|
||||
('SparseApplyAdagrad', {
|
||||
'block': P.SparseApplyAdagrad(0.5),
|
||||
'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))],
|
||||
'block': SparseApplyAdagradNet(),
|
||||
'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))],
|
||||
'desc_bprop': [[3, 3], [3, 3]],
|
||||
'skip': ['backward']}),
|
||||
('SparseApplyFtrl', {
|
||||
|
@ -1332,13 +1343,6 @@ test_case_nn_ops = [
|
|||
Tensor([[-1.4, -0.7], [0.9, 0.7]], mstype.float16)],
|
||||
'desc_bprop': [],
|
||||
'skip': ['backward']}),
|
||||
('SparseApplyAdagrad', {
|
||||
'block': P.SparseApplyAdagrad(0.5),
|
||||
'desc_inputs': [Tensor([[0.7, 0.2], [0.1, 0.07]], mstype.float32),
|
||||
Tensor([[0.2, 0.2], [0.1, 0.4]], mstype.float32),
|
||||
Tensor([[0.5, 0.4], [0.6, 0.1]], mstype.float32), Tensor([1, 1], mstype.int32)],
|
||||
'desc_bprop': [Tensor([[0.7, 0.2], [0.1, 0.07]], mstype.float32)],
|
||||
'skip': ['backward']}),
|
||||
('DataFormatDimMap', {
|
||||
'block': P.DataFormatDimMap(),
|
||||
'desc_inputs': [Tensor([0, 1, 2, 3], mstype.int32)],
|
||||
|
|
Loading…
Reference in New Issue