From: @wukesong
Reviewed-by: @oacjiewen,@liangchenghui
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2021-03-22 09:39:38 +08:00 committed by Gitee
commit 452016fda9
1 changed files with 4 additions and 2 deletions

View File

@ -49,6 +49,8 @@ def train():
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
device_target="Ascend", device_id=args.device_id)
# init multicards training
args.rank = 0
args.group_size = 1
if device_num > 1:
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num)