From 54a22e7e781c37b9ea63fc55920988658137e887 Mon Sep 17 00:00:00 2001 From: chendongsheng Date: Mon, 19 Apr 2021 15:30:50 +0800 Subject: [PATCH] fixed not support ps float16 --- model_zoo/official/cv/resnet/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index f875149dba0..2076924fa72 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -207,12 +207,14 @@ if __name__ == '__main__': metrics = {"acc"} if args_opt.run_distribute: metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=args_opt.device_num)} - model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics, - amp_level="O2", keep_batchnorm_fp32=False, eval_network=dist_eval_network) if (args_opt.net != "resnet101" and args_opt.net != "resnet50") or \ args_opt.parameter_server or target == "CPU": ## fp32 training model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics, eval_network=dist_eval_network) + else: + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics, + amp_level="O2", keep_batchnorm_fp32=False, eval_network=dist_eval_network) + if cfg.optimizer == "Thor" and args_opt.dataset == "imagenet2012": from src.lr_generator import get_thor_damping damping = get_thor_damping(0, config.damping_init, config.damping_decay, 70, step_size)