From 3805f2a4dee27903d8eddbd067702f496775e9bc Mon Sep 17 00:00:00 2001 From: guoqi Date: Sat, 25 Jul 2020 16:42:44 +0800 Subject: [PATCH] fix a typo in modelzoo/resnet --- model_zoo/official/cv/resnet/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index 5f4a1423837..be1c6290b1f 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -127,11 +127,11 @@ if __name__ == '__main__': lr = Tensor(lr) # define opt - decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainalbe_params())) - no_decayed_params = [param for param in net.trainalbe_params() if param not in decayed_params] + decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainable_params())) + no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params] group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay}, {'params': no_decayed_params}, - {'order_params': net.trainalbe_params()}] + {'order_params': net.trainable_params()}] opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale) # define loss, model if target == "Ascend":