forked from mindspore-Ecosystem/mindspore
Define the default decay_filter for `Adam` optimizer.
This commit is contained in:
parent
b2b1adff2f
commit
ead50a2170
|
@ -166,7 +166,8 @@ class Adam(Optimizer):
|
|||
"""
|
||||
|
||||
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
|
||||
use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
|
||||
use_nesterov=False, weight_decay=0.0, loss_scale=1.0,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
|
||||
super(Adam, self).__init__(learning_rate, params)
|
||||
_check_param_value(beta1, beta2, eps, weight_decay)
|
||||
validator.check_type("use_locking", use_locking, [bool])
|
||||
|
@ -192,6 +193,7 @@ class Adam(Optimizer):
|
|||
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
|
||||
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
|
||||
|
||||
self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.Adam(use_locking, use_nesterov)
|
||||
self.weight_decay = weight_decay * loss_scale
|
||||
|
|
Loading…
Reference in New Issue