forked from mindspore-Ecosystem/mindspore
!16817 open graph_kernel flag in mobilenetv2 when platform is GPU
From: @zengzitao Reviewed-by: @gaoxiong1,@anyrenwei Signed-off-by: @anyrenwei
This commit is contained in:
commit
a83b9f9a83
|
@ -19,6 +19,7 @@ import time
|
|||
import random
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import WithLossCell, TrainOneStepCell
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
|
@ -56,6 +57,8 @@ if __name__ == '__main__':
|
|||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config,
|
||||
enable_cache=args_opt.enable_cache, cache_session_id=args_opt.cache_session_id)
|
||||
step_size = dataset.get_dataset_size()
|
||||
if config.platform == "GPU":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
if args_opt.pretrain_ckpt:
|
||||
if args_opt.freeze_layer == "backbone":
|
||||
load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False)
|
||||
|
@ -91,8 +94,8 @@ if __name__ == '__main__':
|
|||
|
||||
if args_opt.pretrain_ckpt == "" or args_opt.freeze_layer != "backbone":
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, \
|
||||
config.weight_decay, config.loss_scale)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
|
||||
config.weight_decay, config.loss_scale)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
|
||||
|
||||
cb = config_ckpoint(config, lr, step_size)
|
||||
|
@ -101,7 +104,8 @@ if __name__ == '__main__':
|
|||
print("============== End Training ==============")
|
||||
|
||||
else:
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, head_net.get_parameters()), lr, config.momentum, config.weight_decay)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, head_net.get_parameters()),
|
||||
lr, config.momentum, config.weight_decay)
|
||||
|
||||
network = WithLossCell(head_net, loss)
|
||||
network = TrainOneStepCell(network, opt)
|
||||
|
|
Loading…
Reference in New Issue