enable graph kernel for default model ssd300 on GPU back-end

This commit is contained in:
looop5 2021-05-20 09:28:15 +08:00
parent 408f686970
commit 8ba47a92e3
1 changed files with 6 additions and 0 deletions

View File

@ -114,6 +114,11 @@ def ssd_model_build(args_opt):
raise ValueError(f'config.model: {config.model} is not supported')
return ssd
def set_graph_kernel_context(run_platform, model):
if run_platform == "GPU" and model == "ssd300":
# Enable graph kernel for default model ssd300 on GPU back-end.
context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion")
def main():
args_opt = get_args()
rank = 0
@ -122,6 +127,7 @@ def main():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
else:
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)
set_graph_kernel_context(args_opt.run_platform, config.model)
if args_opt.distribute:
device_num = args_opt.device_num
context.reset_auto_parallel_context()