forked from mindspore-Ecosystem/mindspore
enable graph kernl when training resnet101 on GPU back-end
This commit is contained in:
parent
64906321fc
commit
84bff86eba
|
@ -116,6 +116,10 @@ def apply_eval(eval_param):
|
|||
res = eval_model.eval(eval_ds)
|
||||
return res[metrics_name]
|
||||
|
||||
def set_graph_kernel_context(run_platform, net_name):
|
||||
if run_platform == "GPU" and net_name == "resnet101":
|
||||
context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion")
|
||||
|
||||
if __name__ == '__main__':
|
||||
target = args_opt.device_target
|
||||
if target == "CPU":
|
||||
|
@ -126,6 +130,7 @@ if __name__ == '__main__':
|
|||
# init context
|
||||
if args_opt.mode == 'GRAPH':
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||
set_graph_kernel_context(target, args_opt.net)
|
||||
else:
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target=target, save_graphs=False)
|
||||
if args_opt.parameter_server:
|
||||
|
|
Loading…
Reference in New Issue