!16817 open graph_kernel flag in mobilenetv2 when platform is GPU

From: @zengzitao
Reviewed-by: @gaoxiong1, @anyrenwei
Signed-off-by: @anyrenwei
Committed-by: mindspore-ci-bot (Gitee), 2021-05-25 10:15:03 +08:00
Commit: a83b9f9a83
1 changed file with 7 additions and 3 deletions

@@ -19,6 +19,7 @@
 import time
 import random
 import numpy as np
+from mindspore import context
 from mindspore import Tensor
 from mindspore.nn import WithLossCell, TrainOneStepCell
 from mindspore.nn.optim.momentum import Momentum
@@ -56,6 +57,8 @@ if __name__ == '__main__':
     dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config,
                              enable_cache=args_opt.enable_cache, cache_session_id=args_opt.cache_session_id)
     step_size = dataset.get_dataset_size()
+    if config.platform == "GPU":
+        context.set_context(enable_graph_kernel=True)
     if args_opt.pretrain_ckpt:
         if args_opt.freeze_layer == "backbone":
             load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False)
@@ -91,8 +94,8 @@ if __name__ == '__main__':
     if args_opt.pretrain_ckpt == "" or args_opt.freeze_layer != "backbone":
         loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
-        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, \
-                       config.weight_decay, config.loss_scale)
+        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
+                       config.weight_decay, config.loss_scale)
         model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
         cb = config_ckpoint(config, lr, step_size)
@@ -101,7 +104,8 @@ if __name__ == '__main__':
         print("============== End Training ==============")
     else:
-        opt = Momentum(filter(lambda x: x.requires_grad, head_net.get_parameters()), lr, config.momentum, config.weight_decay)
+        opt = Momentum(filter(lambda x: x.requires_grad, head_net.get_parameters()),
+                       lr, config.momentum, config.weight_decay)
         network = WithLossCell(head_net, loss)
         network = TrainOneStepCell(network, opt)
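
For reference, the flag toggled in this change is MindSpore's graph kernel fusion switch, which fuses compatible operators into larger kernels at graph compile time. A minimal sketch of applying the same setting in a standalone GPU training script (assuming MindSpore 1.x and the mindspore.context API used in the diff; illustrative only, not part of this commit):

# Minimal sketch (not part of this commit): enable graph kernel fusion for a GPU run.
# Graph kernel fusion is generally used in graph mode; the diff above guards the flag
# with a platform check so it is only set when config.platform == "GPU".
from mindspore import context

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
if context.get_context("device_target") == "GPU":
    context.set_context(enable_graph_kernel=True)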