From 2b0550b77abaa397680c8a553916594cd6be4cd9 Mon Sep 17 00:00:00 2001 From: panfengfeng Date: Tue, 10 Nov 2020 15:42:46 +0800 Subject: [PATCH] update nasnet scripts --- model_zoo/official/cv/nasnet/train.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/model_zoo/official/cv/nasnet/train.py b/model_zoo/official/cv/nasnet/train.py index ba79c88dbaa..297f25cc6b3 100755 --- a/model_zoo/official/cv/nasnet/train.py +++ b/model_zoo/official/cv/nasnet/train.py @@ -29,7 +29,7 @@ from mindspore.common import dtype as mstype from src.config import nasnet_a_mobile_config_gpu as cfg from src.dataset import create_dataset -from src.nasnet_a_mobile import NASNetAMobileWithLoss, NASNetAMobileTrainOneStepWithClipGradient +from src.nasnet_a_mobile import NASNetAMobileWithLoss from src.lr_generator import get_lr @@ -104,9 +104,13 @@ if __name__ == '__main__': optimizer = RMSProp(group_params, lr, decay=cfg.rmsprop_decay, weight_decay=cfg.weight_decay, momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale) - net_with_grads = NASNetAMobileTrainOneStepWithClipGradient(net_with_loss, optimizer) - net_with_grads.set_train() - model = Model(net_with_grads) + # net_with_grads = NASNetAMobileTrainOneStepWithClipGradient(net_with_loss, optimizer) + # net_with_grads.set_train() + # model = Model(net_with_grads) + + # high performance + net_with_loss.set_train() + model = Model(net_with_loss, optimizer=optimizer) print("============== Starting Training ==============") loss_cb = LossMonitor(per_print_times=batches_per_epoch)