!3494 fix a typo in resnet

Merge pull request !3494 from guoqi/master
This commit is contained in:
mindspore-ci-bot 2020-07-27 09:47:08 +08:00 committed by Gitee
commit 843cf1fcbb
1 changed files with 3 additions and 3 deletions

View File

@ -127,11 +127,11 @@ if __name__ == '__main__':
lr = Tensor(lr) lr = Tensor(lr)
# define opt # define opt
decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainalbe_params())) decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainable_params()))
no_decayed_params = [param for param in net.trainalbe_params() if param not in decayed_params] no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params]
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay}, group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
{'params': no_decayed_params}, {'params': no_decayed_params},
{'order_params': net.trainalbe_params()}] {'order_params': net.trainable_params()}]
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale) opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
# define loss, model # define loss, model
if target == "Ascend": if target == "Ascend":