From 5db63a703fb341efd87f68fbc11e28c2dc826cae Mon Sep 17 00:00:00 2001 From: looop5 Date: Mon, 7 Jun 2021 11:29:54 +0800 Subject: [PATCH] enable graph kernel when training retinaface_resnet50 on GPU --- akg | 2 +- model_zoo/official/cv/retinaface_resnet50/train.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/akg b/akg index 5dbebd8613b..67533ceaa0f 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit 5dbebd8613b97bc6723bf8b29ee5ab480dfd6110 +Subproject commit 67533ceaa0f0d2e5f9728c25b8c290ec15b56ef4 diff --git a/model_zoo/official/cv/retinaface_resnet50/train.py b/model_zoo/official/cv/retinaface_resnet50/train.py index 550d8cbb721..bb9a2f5c813 100644 --- a/model_zoo/official/cv/retinaface_resnet50/train.py +++ b/model_zoo/official/cv/retinaface_resnet50/train.py @@ -33,6 +33,9 @@ from src.lr_schedule import adjust_learning_rate def train(cfg): context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False) + if context.get_context("device_target") == "GPU": + # Enable graph kernel + context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion") if cfg['ngpu'] > 1: init("nccl") context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,