!2668 support weight decay for sparse optimizer

Merge pull request !2668 from wangnan39/support_weight_decay_for_sparse_optimizer
mindspore-ci-bot 2020-06-28 22:24:45 +08:00 committed by Gitee
commit 9be17e2a59
10 changed files with 26 additions and 17 deletions

@@ -164,7 +164,7 @@ class Adam(Optimizer):
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.
     Args:
         params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
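For reference, a minimal usage sketch of the combination this change enables (it mirrors the updated test_sparse_adam_compile test further below; NetWithSparseGatherV2 is the test helper cell, defined in the test file rather than in this diff, whose embedding Parameter has `sparse_grad` set and is read through SparseGatherV2):

```python
from mindspore.nn import Adam, TrainOneStepCell

net = NetWithSparseGatherV2()   # test helper cell, defined in the test file (not in this diff)
net.set_train()

# A non-zero weight_decay can now be combined with the sparse (CPU) path.
optimizer = Adam(net.trainable_params(), learning_rate=0.1,
                 loss_scale=1024.0, weight_decay=0.9)
train_network = TrainOneStepCell(net, optimizer)
```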

@@ -73,7 +73,7 @@ class FTRL(Optimizer):
     Note:
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.
     Args:
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
@@ -124,7 +124,7 @@ class FTRL(Optimizer):
         linear = self.linear
         lr = self.learning_rate
         if self.weight_decay > 0.0:
-            grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
+            grads = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
         grads = self.scale_grad(grads)
         success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),

@@ -94,8 +94,7 @@ class LazyAdam(Optimizer):
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse behavior, to be notice, is not equivalent to the
         original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under
-        continuous development. The sparse behavior is currently performed on the CPU, weight decay is
-        not supported.
+        continuous development. The sparse behavior is currently performed on the CPU.
     Args:
         params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,

@@ -195,12 +195,12 @@ class Optimizer(Cell):
         params = self.parameters
         if self.is_group:
             if self.exec_weight_decay:
-                gradients = self.hyper_map(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
-                                           params, gradients)
+                gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
+                                      params, gradients)
         else:
             if self.weight_decay > 0:
-                gradients = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
-                                           params, gradients)
+                gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
+                                      params, gradients)
         return gradients
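The edit above only swaps `self.hyper_map` for `self.map_` when applying `_apply_decay`; the decay arithmetic itself is unchanged. A NumPy sketch of what the two branches compute per parameter (a hypothetical dense stand-in for `_apply_decay`; in the grouped branch `self.weight_decay` is a sequence with one decay value per parameter, mapped alongside the flags, while the ungrouped branch binds a single scalar via `F.partial`):

```python
import numpy as np

def apply_decay(weight_decay, if_apply, weight, gradient):
    """Dense stand-in for _apply_decay: fold weight_decay * weight into the gradient."""
    return gradient + weight_decay * weight if if_apply else gradient

params = (np.array([1.0, 2.0]), np.array([3.0]))
grads  = (np.array([0.1, 0.1]), np.array([0.1]))
flags  = (True, False)                 # decay_flags: which parameters get decayed

# grouped branch: a per-parameter decay sequence is zipped with the flags
wd_per_param = (0.9, 0.0)
grouped = tuple(apply_decay(wd, f, w, g)
                for wd, f, w, g in zip(wd_per_param, flags, params, grads))

# ungrouped branch: one scalar decay value, flags still gate the application
scalar = tuple(apply_decay(0.9, f, w, g)
               for f, w, g in zip(flags, params, grads))

print(grouped)  # first parameter decayed, second left as-is
print(scalar)   # same result here, since the second flag is False
```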
@@ -479,10 +479,20 @@ class Optimizer(Cell):
 op_add = P.AddN()
+op_gather = P.GatherV2()
 _apply_decay = C.MultitypeFuncGraph("apply_decay")
+
+@_apply_decay.register("Number", "Bool", "Tensor", "Tuple")
+def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
+    """Get grad with weight_decay."""
+    if if_apply:
+        weight = op_gather(weight, gradient[0], 0)
+        return gradient[0], op_add((weight * weight_decay, gradient[1])), gradient[2]
+    return gradient
+
 @_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
 def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""

@@ -60,7 +60,7 @@ class ProximalAdagrad(Optimizer):
     Note:
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.
     Args:
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`

@@ -107,7 +107,7 @@ def test_sparse_adam_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
-    optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0)
+    optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)

@@ -71,6 +71,6 @@ def test_spares_ftrl_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
-    optimizer = FTRL(net.trainable_params(), loss_scale=2.0)
+    optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)

@@ -75,7 +75,7 @@ def test_spares_lazy_adam_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
-    optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, loss_scale=2.0)
+    optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)

@@ -57,7 +57,7 @@ def test_proximal_ada_grad():
     net = Net()
     net.set_train()
     loss = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = ProximalAdagrad(net.trainable_params())
+    optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
@@ -70,6 +70,6 @@ def test_spares_proximal_ada_grad_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
-    optimizer = ProximalAdagrad(net.trainable_params(), loss_scale=2.0)
+    optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)

@@ -57,7 +57,7 @@ def test_rmsprop_compile():
 def test_rmsprop_e():
     net = Net()
     with pytest.raises(ValueError):
-        RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1)
+        RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1, weight_decay=0.9)
     with pytest.raises(TypeError):
-        RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1)
+        RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1, weight_decay=0.9)