update train,py for resent

This commit is contained in:
VectorSL 2021-03-25 10:43:35 +08:00
parent 4b329090b6
commit 27f32783cd
1 changed files with 13 additions and 33 deletions

View File

@ -177,8 +177,6 @@ if __name__ == '__main__':
{'params': no_decayed_params},
{'order_params': net.trainable_params()}]
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
# define loss, model
if target == "Ascend":
if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
@ -189,27 +187,9 @@ if __name__ == '__main__':
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False)
else:
# GPU and CPU target
if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
if (args_opt.net == "resnet101" or args_opt.net == "resnet50") and \
not args_opt.parameter_server and target != "CPU":
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay,
config.loss_scale)
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
# Mixed precision
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False)
else:
if (args_opt.net != "resnet101" and args_opt.net != "resnet50") or \
args_opt.parameter_server or target == "CPU":
## fp32 training
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
if cfg.optimizer == "Thor" and args_opt.dataset == "imagenet2012":
from src.lr_generator import get_thor_damping