forked from mindspore-Ecosystem/mindspore
enable graph kernel for default model ssd300 on GPU back-end
This commit is contained in:
parent
408f686970
commit
8ba47a92e3
|
@ -114,6 +114,11 @@ def ssd_model_build(args_opt):
|
|||
raise ValueError(f'config.model: {config.model} is not supported')
|
||||
return ssd
|
||||
|
||||
def set_graph_kernel_context(run_platform, model):
|
||||
if run_platform == "GPU" and model == "ssd300":
|
||||
# Enable graph kernel for default model ssd300 on GPU back-end.
|
||||
context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion")
|
||||
|
||||
def main():
|
||||
args_opt = get_args()
|
||||
rank = 0
|
||||
|
@ -122,6 +127,7 @@ def main():
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
else:
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)
|
||||
set_graph_kernel_context(args_opt.run_platform, config.model)
|
||||
if args_opt.distribute:
|
||||
device_num = args_opt.device_num
|
||||
context.reset_auto_parallel_context()
|
||||
|
|
Loading…
Reference in New Issue