forked from mindspore-Ecosystem/mindspore
update train,py for resent
This commit is contained in:
parent
4b329090b6
commit
27f32783cd
|
@ -177,8 +177,6 @@ if __name__ == '__main__':
|
|||
{'params': no_decayed_params},
|
||||
{'order_params': net.trainable_params()}]
|
||||
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
|
||||
# define loss, model
|
||||
if target == "Ascend":
|
||||
if args_opt.dataset == "imagenet2012":
|
||||
if not config.use_label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
|
@ -189,27 +187,9 @@ if __name__ == '__main__':
|
|||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
|
||||
amp_level="O2", keep_batchnorm_fp32=False)
|
||||
else:
|
||||
# GPU and CPU target
|
||||
if args_opt.dataset == "imagenet2012":
|
||||
if not config.use_label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
loss = CrossEntropySmooth(sparse=True, reduction="mean",
|
||||
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
else:
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
|
||||
if (args_opt.net == "resnet101" or args_opt.net == "resnet50") and \
|
||||
not args_opt.parameter_server and target != "CPU":
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay,
|
||||
config.loss_scale)
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
# Mixed precision
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
|
||||
amp_level="O2", keep_batchnorm_fp32=False)
|
||||
else:
|
||||
if (args_opt.net != "resnet101" and args_opt.net != "resnet50") or \
|
||||
args_opt.parameter_server or target == "CPU":
|
||||
## fp32 training
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
if cfg.optimizer == "Thor" and args_opt.dataset == "imagenet2012":
|
||||
from src.lr_generator import get_thor_damping
|
||||
|
|
Loading…
Reference in New Issue