From 19762375a5bb579c82fad73c650e92bce985b97b Mon Sep 17 00:00:00 2001
From: "wangnan39@huawei.com" <wangnan39@huawei.com>
Date: Wed, 22 Jul 2020 15:50:02 +0800
Subject: [PATCH] fix bug in sparse proximal ada grad

---
 mindspore/nn/learning_rate_schedule.py  | 12 ++++++++++++
 mindspore/nn/optim/proximal_ada_grad.py |  2 +-
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/mindspore/nn/learning_rate_schedule.py b/mindspore/nn/learning_rate_schedule.py
index b8cde1673b0..181db58e449 100644
--- a/mindspore/nn/learning_rate_schedule.py
+++ b/mindspore/nn/learning_rate_schedule.py
@@ -24,10 +24,22 @@ from .._checkparam import Rel


 class LearningRateSchedule(Cell):
+    """Basic class of learning rate schedule."""
     def __init__(self):
         super(LearningRateSchedule, self).__init__()

     def construct(self, global_step):
+        """
+        Defines the computation to get the current learning rate.
+
+        This method should be overridden by all subclasses.
+
+        Note:
+            The output should be a scalar Tensor.
+
+        Inputs:
+            Tensor. The current step number.
+        """
         raise NotImplementedError


diff --git a/mindspore/nn/optim/proximal_ada_grad.py b/mindspore/nn/optim/proximal_ada_grad.py
index 2ef320fd9c3..616f070d32d 100644
--- a/mindspore/nn/optim/proximal_ada_grad.py
+++ b/mindspore/nn/optim/proximal_ada_grad.py
@@ -24,7 +24,7 @@ _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")

 @_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "IndexedSlices",
                                  "Tensor", "Tensor")
-def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum):
+def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum):
     """Apply sparse proximal_ada_grad optimizer to the weight parameter."""
     success = True
     success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values(),
                                            gradient.indices()))
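
Note: the docstring added to LearningRateSchedule.construct describes the contract every schedule
must satisfy: construct() receives the current step number and returns a scalar Tensor. For
context, a minimal sketch of such a subclass is given below. The class LinearWarmUpLR is
hypothetical and not part of this patch; it is loosely modelled on the warm-up style schedules
shipped with MindSpore, and exact behaviour may differ.

    import mindspore.common.dtype as mstype
    from mindspore.ops import operations as P
    from mindspore.nn.learning_rate_schedule import LearningRateSchedule


    class LinearWarmUpLR(LearningRateSchedule):
        """Hypothetical schedule: ramps the learning rate up linearly over warmup_steps."""
        def __init__(self, learning_rate, warmup_steps):
            super(LinearWarmUpLR, self).__init__()
            self.learning_rate = learning_rate
            self.warmup_steps = warmup_steps
            self.min = P.Minimum()
            self.cast = P.Cast()

        def construct(self, global_step):
            # Clamp the step to the warm-up range, then scale the base rate.
            # The result is a scalar Tensor, as required by LearningRateSchedule.construct.
            warmup_percent = self.cast(self.min(global_step, self.warmup_steps), mstype.float32) / self.warmup_steps
            return self.learning_rate * warmup_percent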