enable graph kernel

This commit is contained in:
liuyihong 2021-05-28 12:40:03 +08:00
parent 63ba14b5c6
commit 143009a51b
1 changed files with 2 additions and 0 deletions

View File

@ -46,6 +46,8 @@ if __name__ == '__main__':
config = train_config
data_sink = (args.device_target == "GPU")
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
if args.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
if args.is_distributed:
init('nccl')
rank_id = get_rank()