enable graph kernel when training retinaface_resnet50 on GPU
This commit is contained in:
parent
c96d4269b1
commit
5db63a703f
2
akg
2
akg
|
@ -1 +1 @@
|
||||||
Subproject commit 5dbebd8613b97bc6723bf8b29ee5ab480dfd6110
|
Subproject commit 67533ceaa0f0d2e5f9728c25b8c290ec15b56ef4
|
|
@ -33,6 +33,9 @@ from src.lr_schedule import adjust_learning_rate
|
||||||
def train(cfg):
|
def train(cfg):
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
|
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
|
||||||
|
if context.get_context("device_target") == "GPU":
|
||||||
|
# Enable graph kernel
|
||||||
|
context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion")
|
||||||
if cfg['ngpu'] > 1:
|
if cfg['ngpu'] > 1:
|
||||||
init("nccl")
|
init("nccl")
|
||||||
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
|
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
|
|
Loading…
Reference in New Issue