forked from mindspore-Ecosystem/mindspore
!2668 support weight decay for sparse optimizer
Merge pull request !2668 from wangnan39/support_weight_decay_for_sparse_optimizer
Commit 9be17e2a59
@@ -164,7 +164,7 @@ class Adam(Optimizer):
 
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.
 
     Args:
         params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
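With the "weight decay is not supported" restriction dropped from the Adam note, weight decay can now be combined with the sparse gradient path. A minimal usage sketch, modeled on the test_sparse_adam_compile change further down in this diff; NetWithSparseGatherV2, indices and label are fixtures defined in that test file, not here:

from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import Adam
from mindspore.common.api import _executor

# NetWithSparseGatherV2 is assumed to be a cell whose forward pass uses
# SparseGatherV2 and whose gathered Parameter has `sparse_grad` set,
# as described in the note above.
net = NetWithSparseGatherV2()
net.set_train()

# weight_decay is now accepted alongside the sparse path.
optimizer = Adam(net.trainable_params(), learning_rate=0.1,
                 loss_scale=1024.0, weight_decay=0.9)
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)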
@@ -73,7 +73,7 @@ class FTRL(Optimizer):
     Note:
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.
 
     Args:
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
@@ -124,7 +124,7 @@ class FTRL(Optimizer):
         linear = self.linear
         lr = self.learning_rate
         if self.weight_decay > 0.0:
-            grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
+            grads = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
 
         grads = self.scale_grad(grads)
         success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
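Conceptually, self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads) walks the decay flags, parameters and gradients in lockstep and dispatches to whichever _apply_decay overload matches the gradient's type, so dense tensors and tuple-shaped sparse gradients share one decay rule. A plain-Python mimic of that dispatch (illustration only, not the MindSpore Map/MultitypeFuncGraph machinery; it assumes sparse gradients travel as (indices, values, dense_shape) tuples, as in the overload added later in this diff):

import numpy as np

def apply_decay(weight_decay, decay_flag, weight, grad):
    """Mimic of _apply_decay: dense grads are arrays, sparse grads are tuples."""
    if not decay_flag:
        return grad
    if isinstance(grad, tuple):                       # sparse: (indices, values, dense_shape)
        indices, values, dense_shape = grad
        return indices, values + weight_decay * weight[indices], dense_shape
    return grad + weight_decay * weight               # dense

def decay_all(weight_decay, decay_flags, params, grads):
    """Lockstep map over (flag, param, grad) triples, as map_ would do."""
    return tuple(apply_decay(weight_decay, f, p, g)
                 for f, p, g in zip(decay_flags, params, grads))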
@@ -94,8 +94,7 @@ class LazyAdam(Optimizer):
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse behavior, to be notice, is not equivalent to the
         original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under
-        continuous development. The sparse behavior is currently performed on the CPU, weight decay is
-        not supported.
+        continuous development. The sparse behavior is currently performed on the CPU.
 
     Args:
         params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
@@ -195,12 +195,12 @@ class Optimizer(Cell):
         params = self.parameters
         if self.is_group:
             if self.exec_weight_decay:
-                gradients = self.hyper_map(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
+                gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
                                       params, gradients)
         else:
             if self.weight_decay > 0:
-                gradients = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
+                gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
                                       params, gradients)
 
         return gradients
 
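In the is_group branch, self.weight_decay is a per-group sequence zipped against self.decay_flags, so the decay strength can differ between parameter groups. A hedged configuration sketch (the 'params'/'weight_decay' group keys follow the MindSpore grouped-parameter convention described in the `params` docstrings above; conv_params and no_decay_params are assumed to be lists of Parameters split by the caller):

from mindspore.nn.optim import Adam

group_params = [
    {'params': conv_params, 'weight_decay': 0.01},  # this group is decayed
    {'params': no_decay_params},                    # falls back to the default below
]
optimizer = Adam(group_params, learning_rate=0.1, weight_decay=0.0)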
@@ -479,10 +479,20 @@ class Optimizer(Cell):
 
 
 op_add = P.AddN()
+op_gather = P.GatherV2()
 
 _apply_decay = C.MultitypeFuncGraph("apply_decay")
 
 
+@_apply_decay.register("Number", "Bool", "Tensor", "Tuple")
+def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
+    """Get grad with weight_decay."""
+    if if_apply:
+        weight = op_gather(weight, gradient[0], 0)
+        return gradient[0], op_add((weight * weight_decay, gradient[1])), gradient[2]
+    return gradient
+
+
 @_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
 def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
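The new _tensor_apply_decay_with_sparse overload is the core of the change: for a sparse gradient (indices, values, dense_shape) it gathers only the touched rows of the weight, scales them by weight_decay, and adds them to the gradient values, leaving untouched rows undecayed. A small NumPy check of the same arithmetic (illustration only; the shapes and numbers are made up):

import numpy as np

weight_decay = 0.9
weight = np.arange(12, dtype=np.float32).reshape(4, 3)  # embedding-like weight, 4 rows
indices = np.array([0, 2])                              # rows touched this step
values = np.ones((2, 3), dtype=np.float32)              # sparse gradient values
dense_shape = (4, 3)

# Equivalent of: op_add((op_gather(weight, gradient[0], 0) * weight_decay, gradient[1]))
decayed_values = values + weight_decay * weight[indices]

# Rows 1 and 3 receive no decay, which keeps the update sparse.
print(decayed_values)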
@@ -60,7 +60,7 @@ class ProximalAdagrad(Optimizer):
     Note:
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.
 
     Args:
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
@@ -107,7 +107,7 @@ def test_sparse_adam_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
 
-    optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0)
+    optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
 
@@ -71,6 +71,6 @@ def test_spares_ftrl_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
 
-    optimizer = FTRL(net.trainable_params(), loss_scale=2.0)
+    optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
@@ -75,7 +75,7 @@ def test_spares_lazy_adam_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
 
-    optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, loss_scale=2.0)
+    optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
 
@@ -57,7 +57,7 @@ def test_proximal_ada_grad():
     net = Net()
     net.set_train()
     loss = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = ProximalAdagrad(net.trainable_params())
+    optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
@@ -70,6 +70,6 @@ def test_spares_proximal_ada_grad_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
 
-    optimizer = ProximalAdagrad(net.trainable_params(), loss_scale=2.0)
+    optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
@@ -57,7 +57,7 @@ def test_rmsprop_compile():
 def test_rmsprop_e():
     net = Net()
     with pytest.raises(ValueError):
-        RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1)
+        RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1, weight_decay=0.9)
 
     with pytest.raises(TypeError):
-        RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1)
+        RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1, weight_decay=0.9)