forked from mindspore-Ecosystem/mindspore
fix a typo in modelzoo/resnet
This commit is contained in:
parent
7fbed0ce94
commit
3805f2a4de
|
@ -127,11 +127,11 @@ if __name__ == '__main__':
|
|||
lr = Tensor(lr)
|
||||
|
||||
# define opt
|
||||
decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainalbe_params()))
|
||||
no_decayed_params = [param for param in net.trainalbe_params() if param not in decayed_params]
|
||||
decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainable_params()))
|
||||
no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params]
|
||||
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
|
||||
{'params': no_decayed_params},
|
||||
{'order_params': net.trainalbe_params()}]
|
||||
{'order_params': net.trainable_params()}]
|
||||
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
|
||||
# define loss, model
|
||||
if target == "Ascend":
|
||||
|
|
Loading…
Reference in New Issue