From ead50a21700a0204589ac23e58bf0caf5b7498c2 Mon Sep 17 00:00:00 2001
From: seatea
Date: Thu, 9 Apr 2020 18:24:21 +0800
Subject: [PATCH] Define the default decay_filter for `Adam` optimizer.

---
 mindspore/nn/optim/adam.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index 86ce2b21472..521510fa585 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -166,7 +166,8 @@ class Adam(Optimizer):
     """

     def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
-                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
+                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0,
+                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
         super(Adam, self).__init__(learning_rate, params)
         _check_param_value(beta1, beta2, eps, weight_decay)
         validator.check_type("use_locking", use_locking, [bool])
@@ -192,6 +193,7 @@ class Adam(Optimizer):

         self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
         self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
+        self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
         self.hyper_map = C.HyperMap()
         self.opt = P.Adam(use_locking, use_nesterov)
         self.weight_decay = weight_decay * loss_scale
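
A minimal usage sketch (not part of the patch) of how a caller might rely on the new
decay_filter default or override it. It assumes nn.Dense and trainable_params() from
MindSpore's public API of that era; only the decay_filter keyword comes from this change,
everything else is illustrative.

    # Sketch only: demonstrates the decay_filter keyword added by this patch.
    import mindspore.nn as nn

    net = nn.Dense(16, 4)  # any nn.Cell with named parameters works here

    # Default filter: parameters whose names contain 'beta' or 'gamma'
    # (typically normalization affine parameters) are excluded from weight decay.
    opt_default = nn.Adam(net.trainable_params(), learning_rate=1e-3, weight_decay=1e-4)

    # Custom filter: additionally exclude bias parameters from weight decay.
    opt_custom = nn.Adam(net.trainable_params(), learning_rate=1e-3, weight_decay=1e-4,
                         decay_filter=lambda p: 'beta' not in p.name
                         and 'gamma' not in p.name
                         and 'bias' not in p.name)

The filter is evaluated once per parameter at construction time (the decay_tf tuple in the
diff), so the lambda only needs to inspect the parameter name, not its value.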