enable graph kernl when training resnet101 on GPU back-end

This commit is contained in:
looop5 2021-05-24 09:39:08 +08:00
parent 64906321fc
commit 84bff86eba
1 changed files with 5 additions and 0 deletions

View File

@ -116,6 +116,10 @@ def apply_eval(eval_param):
res = eval_model.eval(eval_ds) res = eval_model.eval(eval_ds)
return res[metrics_name] return res[metrics_name]
def set_graph_kernel_context(run_platform, net_name):
if run_platform == "GPU" and net_name == "resnet101":
context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion")
if __name__ == '__main__': if __name__ == '__main__':
target = args_opt.device_target target = args_opt.device_target
if target == "CPU": if target == "CPU":
@ -126,6 +130,7 @@ if __name__ == '__main__':
# init context # init context
if args_opt.mode == 'GRAPH': if args_opt.mode == 'GRAPH':
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
set_graph_kernel_context(target, args_opt.net)
else: else:
context.set_context(mode=context.PYNATIVE_MODE, device_target=target, save_graphs=False) context.set_context(mode=context.PYNATIVE_MODE, device_target=target, save_graphs=False)
if args_opt.parameter_server: if args_opt.parameter_server: