From 84bff86eba7800bb88c991745dcfe899975300b8 Mon Sep 17 00:00:00 2001 From: looop5 Date: Mon, 24 May 2021 09:39:08 +0800 Subject: [PATCH] enable graph kernl when training resnet101 on GPU back-end --- model_zoo/official/cv/resnet/train.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index 1c6fb2aa572..9ab2d7e44f1 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -116,6 +116,10 @@ def apply_eval(eval_param): res = eval_model.eval(eval_ds) return res[metrics_name] +def set_graph_kernel_context(run_platform, net_name): + if run_platform == "GPU" and net_name == "resnet101": + context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion") + if __name__ == '__main__': target = args_opt.device_target if target == "CPU": @@ -126,6 +130,7 @@ if __name__ == '__main__': # init context if args_opt.mode == 'GRAPH': context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + set_graph_kernel_context(target, args_opt.net) else: context.set_context(mode=context.PYNATIVE_MODE, device_target=target, save_graphs=False) if args_opt.parameter_server: