forked from mindspore-Ecosystem/mindspore
enable graph kernl when training resnet101 on GPU back-end
This commit is contained in:
parent
64906321fc
commit
84bff86eba
|
@ -116,6 +116,10 @@ def apply_eval(eval_param):
|
||||||
res = eval_model.eval(eval_ds)
|
res = eval_model.eval(eval_ds)
|
||||||
return res[metrics_name]
|
return res[metrics_name]
|
||||||
|
|
||||||
|
def set_graph_kernel_context(run_platform, net_name):
|
||||||
|
if run_platform == "GPU" and net_name == "resnet101":
|
||||||
|
context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
target = args_opt.device_target
|
target = args_opt.device_target
|
||||||
if target == "CPU":
|
if target == "CPU":
|
||||||
|
@ -126,6 +130,7 @@ if __name__ == '__main__':
|
||||||
# init context
|
# init context
|
||||||
if args_opt.mode == 'GRAPH':
|
if args_opt.mode == 'GRAPH':
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||||
|
set_graph_kernel_context(target, args_opt.net)
|
||||||
else:
|
else:
|
||||||
context.set_context(mode=context.PYNATIVE_MODE, device_target=target, save_graphs=False)
|
context.set_context(mode=context.PYNATIVE_MODE, device_target=target, save_graphs=False)
|
||||||
if args_opt.parameter_server:
|
if args_opt.parameter_server:
|
||||||
|
|
Loading…
Reference in New Issue