fix bug in sparse proximal ada grad

wangnan39@huawei.com 2020-07-22 15:50:02 +08:00
parent e09d50e4d6
commit 19762375a5
2 changed files with 13 additions and 1 deletion


@@ -24,10 +24,22 @@ from .._checkparam import Rel
class LearningRateSchedule(Cell):
    """Base class of learning rate schedules."""
    def __init__(self):
        super(LearningRateSchedule, self).__init__()

    def construct(self, global_step):
        """
        Defines the computation to get the current learning rate.

        This method should be overridden by all subclasses.

        Note:
            The output should be a scalar Tensor.

        Inputs:
            global_step (Tensor): The current step number.
        """
        raise NotImplementedError
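For context, here is a minimal sketch of a subclass honoring the construct contract documented above. It is not part of this commit: the class name and decay formula are illustrative, and it assumes LearningRateSchedule is importable from mindspore.nn as in released MindSpore versions.

import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.nn import LearningRateSchedule

class SimpleExponentialDecayLR(LearningRateSchedule):
    """Illustrative schedule: lr * decay_rate ** (step / decay_steps)."""
    def __init__(self, learning_rate, decay_rate, decay_steps):
        super(SimpleExponentialDecayLR, self).__init__()
        self.learning_rate = learning_rate
        self.decay_rate = decay_rate
        self.decay_steps = decay_steps
        self.pow = P.Pow()
        self.cast = P.Cast()

    def construct(self, global_step):
        # Returns a scalar Tensor, as the base class requires.
        p = self.cast(global_step, mstype.float32) / self.decay_steps
        return self.learning_rate * self.pow(self.decay_rate, p)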


@@ -24,7 +24,7 @@ _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "IndexedSlices", "Tensor",
                                 "Tensor")
-def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum):
+def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum):
    """Apply sparse proximal_ada_grad optimizer to the weight parameter."""
    success = True
    success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values(), gradient.indices()))
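Why the reorder fixes the bug: MultitypeFuncGraph dispatches on argument types only, and learning_rate, l1, and l2 are all Tensors, so the old signature still matched the registration but bound the caller's positionally-passed values to the wrong parameters. The call site is not shown in this diff; judging from the fixed order, the optimizer presumably binds (opt, sparse_opt, l1, l2, learning_rate) via F.partial before mapping over the parameters. A minimal pure-Python sketch of that failure mode follows; every name and value in it is an illustrative stand-in, not MindSpore code.

from functools import partial

def sparse_opt(weight, accum, lr, l1, l2, grad):
    # Stand-in for the sparse proximal AdaGrad kernel: just report
    # which value landed in which hyperparameter slot.
    return {"lr": lr, "l1": l1, "l2": l2}

def run_opt_buggy(opt, sp_opt, learning_rate, l1, l2, grad, weight, accum):
    # Old signature: learning_rate declared before l1/l2.
    return sp_opt(weight, accum, learning_rate, l1, l2, grad)

def run_opt_fixed(opt, sp_opt, l1, l2, learning_rate, grad, weight, accum):
    # Fixed signature: matches the caller's positional binding below.
    return sp_opt(weight, accum, learning_rate, l1, l2, grad)

L1_VAL, L2_VAL, LR_VAL = 0.001, 0.002, 0.1
# The caller binds hyperparameters positionally as (opt, sp_opt, l1, l2, lr).
print(partial(run_opt_buggy, None, sparse_opt, L1_VAL, L2_VAL, LR_VAL)("g", "w", "a"))
# -> {'lr': 0.001, 'l1': 0.002, 'l2': 0.1}   l1 silently became the learning rate
print(partial(run_opt_fixed, None, sparse_opt, L1_VAL, L2_VAL, LR_VAL)("g", "w", "a"))
# -> {'lr': 0.1, 'l1': 0.001, 'l2': 0.002}   every value lands where intended

Because the mismatch involved three parameters of the same type, no dispatch error was raised; the optimizer simply trained with l1 as the learning rate, which is why the fix is a pure signature reorder with no behavioral change at the call site.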