forked from mindspore-Ecosystem/mindspore
add weight decay in RMSProp optimizer
This commit is contained in:
parent
0c81759ae6
commit
1b4041a8f1
|
@ -18,7 +18,8 @@ from mindspore.common.initializer import initializer
|
|||
from mindspore.common.parameter import Parameter
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
import mindspore.common.dtype as mstype
|
||||
from .optimizer import Optimizer, grad_scale
|
||||
from mindspore.common import Tensor
|
||||
from .optimizer import Optimizer, grad_scale, apply_decay
|
||||
|
||||
rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
|
||||
centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
|
||||
|
@ -118,6 +119,9 @@ class RMSProp(Optimizer):
|
|||
use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False.
|
||||
centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False
|
||||
loss_scale (float): A floating point value for the loss scale. Default: 1.0.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
|
||||
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
|
||||
lambda x: 'beta' not in x.name and 'gamma' not in x.name.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
@ -132,7 +136,8 @@ class RMSProp(Optimizer):
|
|||
>>> model = Model(net, loss, opt)
|
||||
"""
|
||||
def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
|
||||
use_locking=False, centered=False, loss_scale=1.0):
|
||||
use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
|
||||
super(RMSProp, self).__init__(learning_rate, params)
|
||||
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
|
@ -159,6 +164,7 @@ class RMSProp(Optimizer):
|
|||
self.assignadd = P.AssignAdd()
|
||||
self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
|
||||
self.axis = 0
|
||||
self.one = Tensor(1, mstype.int32)
|
||||
|
||||
self.momentum = momentum
|
||||
|
||||
|
@ -167,10 +173,14 @@ class RMSProp(Optimizer):
|
|||
self.hyper_map = C.HyperMap()
|
||||
|
||||
self.decay = decay
|
||||
self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
|
||||
self.reciprocal_scale = 1.0 / loss_scale
|
||||
self.weight_decay = weight_decay * loss_scale
|
||||
|
||||
def construct(self, gradients):
|
||||
params = self.parameters
|
||||
if self.weight_decay > 0:
|
||||
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
|
||||
if self.reciprocal_scale != 1.0:
|
||||
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
|
||||
if self.dynamic_lr:
|
||||
|
|
Loading…
Reference in New Issue