forked from mindspore-Ecosystem/mindspore
!2668 support weight decay for sparse optimizer
Merge pull request !2668 from wangnan39/support_weight_decay_for_sparse_optimizer
commit 9be17e2a59
@@ -164,7 +164,7 @@ class Adam(Optimizer):
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.
 
     Args:
         params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
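For orientation, here is a minimal usage sketch of what this change enables, modeled on the updated test cases further down. The small network and the exact form of the `sparse_grad` flag are assumptions for illustration, not code from this commit.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class SparseNet(nn.Cell):
    """Hypothetical stand-in for the tests' NetWithSparseGatherV2: a forward pass
    that applies SparseGatherV2 to a Parameter whose `sparse_grad` is set."""
    def __init__(self):
        super(SparseNet, self).__init__()
        self.gather = P.SparseGatherV2()
        self.axis = 0
        # `sparse_grad=True` follows the docstring wording; the exact flag form may differ.
        self.weight = Parameter(Tensor(np.ones([4, 2]).astype(np.float32)),
                                name="weight", sparse_grad=True)

    def construct(self, indices, label):
        return self.gather(self.weight, indices, self.axis) + label

net = SparseNet()
net.set_train()
# After this change, weight_decay can be combined with the sparse execution path.
optimizer = nn.Adam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=1024.0)
train_network = nn.TrainOneStepCell(net, optimizer)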
@@ -73,7 +73,7 @@ class FTRL(Optimizer):
     Note:
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.
 
     Args:
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
@@ -124,7 +124,7 @@ class FTRL(Optimizer):
         linear = self.linear
         lr = self.learning_rate
         if self.weight_decay > 0.0:
-            grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
+            grads = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
 
         grads = self.scale_grad(grads)
         success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
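Read as plain Python, the updated FTRL.construct applies weight decay first, then gradient scaling, then the FTRL step. Below is a rough sketch of that ordering, for illustration only: `ftrl_step` is a hypothetical stand-in for the `_ftrl_opt` graph applied by `map_`, and reciprocal loss scaling is assumed for `scale_grad`.

def ftrl_construct_flow(params, grads, decay_tf, weight_decay, loss_scale, ftrl_step):
    """Plain-Python sketch of the order of operations in FTRL.construct above."""
    if weight_decay > 0.0:
        # _apply_decay: add weight_decay * param to the gradient where the flag is set.
        grads = tuple(p * weight_decay + g if flag else g
                      for flag, p, g in zip(decay_tf, params, grads))
    grads = tuple(g / loss_scale for g in grads)  # scale_grad (assumed reciprocal scaling)
    return tuple(ftrl_step(p, g) for p, g in zip(params, grads))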
@@ -94,8 +94,7 @@ class LazyAdam(Optimizer):
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse behavior, to be notice, is not equivalent to the
         original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under
-        continuous development. The sparse behavior is currently performed on the CPU, weight decay is
-        not supported.
+        continuous development. The sparse behavior is currently performed on the CPU.
 
     Args:
         params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
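As the Note says, the lazy sparse path only updates the rows whose indices appear in the gradient. A small NumPy illustration of that difference (not MindSpore code; the values are made up):

import numpy as np

param = np.zeros((4, 2), dtype=np.float32)          # 4-row parameter
indices = np.array([0, 2])                          # rows touched by the sparse gradient
values = np.full((2, 2), 0.5, dtype=np.float32)     # gradient values for those rows
lr = 0.1

# Lazy sparse update: only rows 0 and 2 change; rows 1 and 3 keep their old values
# (and, in the real optimizer, their old moment estimates as well).
param[indices] -= lr * values

# Dense Adam, by contrast, would apply an update to every row, because its moving
# averages keep decaying even where the current gradient is zero.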
@@ -195,12 +195,12 @@ class Optimizer(Cell):
         params = self.parameters
         if self.is_group:
             if self.exec_weight_decay:
-                gradients = self.hyper_map(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
-                                           params, gradients)
+                gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
+                                      params, gradients)
         else:
             if self.weight_decay > 0:
-                gradients = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
-                                           params, gradients)
+                gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
+                                      params, gradients)
 
         return gradients
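For readers unfamiliar with `map_` and `MultitypeFuncGraph`, here is a rough plain-Python analogue of the two branches of `decay_weight` above. Illustration only: the real code executes as a graph, and the dense `_apply_decay` registration is assumed to add `weight_decay * weight` to the gradient, consistent with the sparse branch shown in the next hunk.

def decay_weight_sketch(is_group, weight_decay, decay_flags, params, gradients):
    """Plain-Python sketch of Optimizer.decay_weight after this change."""
    if is_group:
        # Group mode: weight_decay is a per-parameter sequence, zipped together
        # with the decay flags, parameters and gradients.
        return tuple(p * wd + g if flag else g
                     for wd, flag, p, g in zip(weight_decay, decay_flags, params, gradients))
    # Non-group mode: a single scalar weight_decay is partially applied.
    if weight_decay > 0:
        return tuple(p * weight_decay + g if flag else g
                     for flag, p, g in zip(decay_flags, params, gradients))
    return tuple(gradients)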
@@ -479,10 +479,20 @@ class Optimizer(Cell):
 
 
 op_add = P.AddN()
+op_gather = P.GatherV2()
 
 _apply_decay = C.MultitypeFuncGraph("apply_decay")
 
 
+@_apply_decay.register("Number", "Bool", "Tensor", "Tuple")
+def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
+    """Get grad with weight_decay."""
+    if if_apply:
+        weight = op_gather(weight, gradient[0], 0)
+        return gradient[0], op_add((weight * weight_decay, gradient[1])), gradient[2]
+    return gradient
+
+
 @_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
 def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
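Numerically, the new sparse registration of `_apply_decay` gathers the weight rows addressed by the gradient's indices, scales them by `weight_decay`, and adds them to the gradient values, passing indices and shape through unchanged. A NumPy sketch of that computation (illustration only; the gradient is read as an (indices, values, shape) tuple, judging from how its three elements are used above):

import numpy as np

weight_decay = 0.9
weight = np.arange(8, dtype=np.float32).reshape(4, 2)   # dense parameter, shape (4, 2)
indices = np.array([0, 2], dtype=np.int32)              # rows touched by the sparse gradient
values = np.ones((2, 2), dtype=np.float32)              # gradient values for those rows
dense_shape = (4, 2)

gathered = weight[indices]                               # op_gather(weight, gradient[0], 0)
decayed_values = gathered * weight_decay + values        # op_add((weight * weight_decay, gradient[1]))
sparse_grad_with_decay = (indices, decayed_values, dense_shape)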
@@ -60,7 +60,7 @@ class ProximalAdagrad(Optimizer):
     Note:
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.
 
     Args:
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
@@ -107,7 +107,7 @@ def test_sparse_adam_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
 
-    optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0)
+    optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
@@ -71,6 +71,6 @@ def test_spares_ftrl_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
 
-    optimizer = FTRL(net.trainable_params(), loss_scale=2.0)
+    optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
@@ -75,7 +75,7 @@ def test_spares_lazy_adam_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
 
-    optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, loss_scale=2.0)
+    optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
@@ -57,7 +57,7 @@ def test_proximal_ada_grad():
     net = Net()
     net.set_train()
     loss = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = ProximalAdagrad(net.trainable_params())
+    optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
@@ -70,6 +70,6 @@ def test_spares_proximal_ada_grad_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
 
-    optimizer = ProximalAdagrad(net.trainable_params(), loss_scale=2.0)
+    optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
@@ -57,7 +57,7 @@ def test_rmsprop_compile():
 def test_rmsprop_e():
     net = Net()
     with pytest.raises(ValueError):
-        RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1)
+        RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1, weight_decay=0.9)
 
     with pytest.raises(TypeError):
-        RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1)
+        RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1, weight_decay=0.9)