diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index d0130ea553a..92cab56a052 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -157,7 +157,7 @@ class Adam(Optimizer): The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the `sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse - behavior is currently performed on the CPU, weight decay and loss scale is not supported. + behavior is currently performed on the CPU, weight decay is not supported. Args: params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py index 78c38cba6ce..d1f49a3791d 100644 --- a/mindspore/nn/optim/ftrl.py +++ b/mindspore/nn/optim/ftrl.py @@ -73,7 +73,7 @@ class FTRL(Optimizer): Note: The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the `sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse - behavior is currently performed on the CPU, weight decay and loss scale is not supported. + behavior is currently performed on the CPU, weight decay is not supported. Args: params (list[Parameter]): A list of parameter, which will be updated. The element in `params` diff --git a/mindspore/nn/optim/lazyadam.py b/mindspore/nn/optim/lazyadam.py index 83ba179c42e..d9df717b8a1 100644 --- a/mindspore/nn/optim/lazyadam.py +++ b/mindspore/nn/optim/lazyadam.py @@ -94,8 +94,7 @@ class LazyAdam(Optimizer): The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the `sparse_grad` of `Parameter` being set as True. The sparse behavior, to be notice, is not equivalent to the original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under - continuous development. The sparse behavior is currently performed on the CPU, weight decay and loss scale is - not supported. + continuous development. The sparse behavior is currently performed on the CPU, weight decay is not supported. Args: params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, diff --git a/mindspore/nn/optim/proximal_ada_grad.py b/mindspore/nn/optim/proximal_ada_grad.py index 06b4535dbb3..380720404ad 100644 --- a/mindspore/nn/optim/proximal_ada_grad.py +++ b/mindspore/nn/optim/proximal_ada_grad.py @@ -22,9 +22,16 @@ from .optimizer import Optimizer _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt") +@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tuple", "Tensor", "Tensor") +def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum): + """Apply sparse proximal_ada_grad optimizer to the weight parameter.""" + success = True + success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient[1], gradient[0])) + return success -@_proximal_ada_grad_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt(opt, learning_rate, l1, l2, gradient, weight, accum): + +@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") +def _tensor_run_opt(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum): """Apply proximal_ada_grad optimizer to the weight parameter.""" success = True success = F.depend(success, opt(weight, accum, learning_rate, l1, l2, gradient)) @@ -50,6 +57,11 @@ class ProximalAdagrad(Optimizer): Refer to paper `Efficient Learning using Forward-Backward Splitting `_. + Note: + The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the + `sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse + behavior is currently performed on the CPU, weight decay is not supported. + Args: params (list[Parameter]): A list of parameter, which will be updated. The element in `params` should be Parameter. @@ -87,6 +99,7 @@ class ProximalAdagrad(Optimizer): self.weight_decay = weight_decay self.hyper_map = C.HyperMap() self.opt = P.ApplyProximalAdagrad(use_locking=use_locking) + self.sparse_opt = P.SparseApplyProximalAdagrad(use_locking=use_locking) def construct(self, grads): params = self.parameters @@ -94,6 +107,6 @@ class ProximalAdagrad(Optimizer): grads = self.decay_weight(grads) grads = self.scale_grad(grads) lr = self.learning_rate - success = self.hyper_map(F.partial(_proximal_ada_grad_opt, self.opt, lr, self.l1, self.l2), - grads, params, accum) + success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2), + grads, params, accum) return success diff --git a/tests/ut/python/nn/optim/test_proximal_ada_grad.py b/tests/ut/python/nn/optim/test_proximal_ada_grad.py index 03bebb9cb28..52e418d39b5 100644 --- a/tests/ut/python/nn/optim/test_proximal_ada_grad.py +++ b/tests/ut/python/nn/optim/test_proximal_ada_grad.py @@ -36,6 +36,18 @@ class Net(nn.Cell): x = self.biasAdd(self.matmul(x, self.weight), self.bias) return x +class NetWithSparseGatherV2(nn.Cell): + """ NetWithSparseGatherV2 definition """ + def __init__(self): + super(NetWithSparseGatherV2, self).__init__() + self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1", sparse_grad=True) + self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="weight2") + self.axis = 0 + self.gather = P.SparseGatherV2() + + def construct(self, indices, label): + return self.gather(self.weight1, indices, self.axis) + self.weight2 + def test_proximal_ada_grad(): """ test_proximal_ada_grad """ @@ -48,3 +60,15 @@ def test_proximal_ada_grad(): net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, optimizer) _executor.compile(train_network, inputs, label) + + +def test_spares_proximal_ada_grad_compile(): + """ test sparse proximal_ada_grad compile """ + indices = Tensor(np.array([0, 1]).astype(np.int32)) + label = Tensor(np.zeros([2, 1, 2]).astype(np.float32)) + net = NetWithSparseGatherV2() + net.set_train() + + optimizer = ProximalAdagrad(net.trainable_params(), loss_scale=2.0) + train_network = TrainOneStepCell(net, optimizer) + _executor.compile(train_network, indices, label)