forked from mindspore-Ecosystem/mindspore
fix bug in sparse proximal ada grad
This commit is contained in:
parent e09d50e4d6
commit 19762375a5
@@ -24,10 +24,22 @@ from .._checkparam import Rel

class LearningRateSchedule(Cell):
    """Basic class of learning rate schedule."""
    def __init__(self):
        super(LearningRateSchedule, self).__init__()

    def construct(self, global_step):
        """
        Defines the computation to get the current learning rate.

        This method should be overridden by all subclasses.

        Note:
            The output should be a Tensor of scalar.

        Inputs:
            Tensor. The current step number.
        """
        raise NotImplementedError
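For context (not part of this commit): construct() is the hook a user-defined schedule overrides to turn the current step number into a scalar learning-rate Tensor. A minimal sketch of such a subclass follows; the class name, hyperparameters, and the import path for LearningRateSchedule are assumptions for illustration, and the Pow/Cast primitives are the standard mindspore.ops operations:

# Illustrative only: a hypothetical schedule subclass, not code from this commit.
# The import path for LearningRateSchedule is an assumption.
import mindspore.common.dtype as mstype
from mindspore.nn.learning_rate_schedule import LearningRateSchedule
from mindspore.ops import operations as P


class MyExponentialDecayLR(LearningRateSchedule):
    """Decays the learning rate exponentially with the global step."""
    def __init__(self, learning_rate, decay_rate, decay_steps):
        super(MyExponentialDecayLR, self).__init__()
        self.learning_rate = learning_rate
        self.decay_rate = decay_rate
        self.decay_steps = decay_steps
        self.pow = P.Pow()
        self.cast = P.Cast()

    def construct(self, global_step):
        # Returns a scalar Tensor, as the base-class docstring requires.
        p = self.cast(global_step, mstype.float32) / self.decay_steps
        return self.learning_rate * self.pow(self.decay_rate, p)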
@@ -24,7 +24,7 @@ _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")

@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "IndexedSlices", "Tensor",
                                 "Tensor")
-def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum):
+def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum):
    """Apply sparse proximal_ada_grad optimizer to the weight parameter."""
    success = True
    success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values(), gradient.indices()))
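The fix is the one-line reorder above: the optimizer evidently partial-applies l1, l2 and the learning rate ahead of the per-parameter arguments, so the registered function must declare those leading parameters in that same order. Because all three are registered as "Tensor", the old order still matched the type signature; the learning rate was simply consumed as l1 and the other values shifted, silently corrupting the sparse update. A minimal, self-contained sketch of this failure mode in plain Python, with functools.partial standing in for F.partial and the hyper map (names, values, and the toy formula are illustrative, not MindSpore code):

# Plain-Python illustration of the bug class fixed above; not MindSpore code.
from functools import partial

def run_opt_old(learning_rate, l1, l2, gradient):
    # Old declaration order: the caller binds (l1, l2, learning_rate) positionally,
    # so `learning_rate` actually receives l1 and every value lands one slot off.
    # The formula is a dummy combination, not the real proximal AdaGrad update.
    return learning_rate * gradient + l1 - l2

def run_opt_fixed(l1, l2, learning_rate, gradient):
    # Fixed declaration order matches the binding order, so each value lands where intended.
    return learning_rate * gradient + l1 - l2

apply_old = partial(run_opt_old, 0.01, 0.001, 0.1)      # intended: l1=0.01, l2=0.001, lr=0.1
apply_fixed = partial(run_opt_fixed, 0.01, 0.001, 0.1)

print(apply_old(2.0))    # ~ -0.079: lr/l1/l2 silently swapped
print(apply_fixed(2.0))  # ~  0.209: hyperparameters applied as intended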