fix fcns2

This commit is contained in:
wukesong 2021-03-20 15:51:09 +08:00
parent dafd2713ac
commit d8bf9ab8da
1 changed files with 4 additions and 2 deletions

View File

@ -49,12 +49,14 @@ def train():
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False, context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
device_target="Ascend", device_id=args.device_id) device_target="Ascend", device_id=args.device_id)
# init multicards training # init multicards training
args.rank = 0
args.group_size = 1
if device_num > 1: if device_num > 1:
parallel_mode = ParallelMode.DATA_PARALLEL parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num) context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num)
init() init()
args.rank = get_rank() args.rank = get_rank()
args.group_size = get_group_size() args.group_size = get_group_size()
# dataset # dataset
dataset = data_generator.SegDataset(image_mean=cfg.image_mean, dataset = data_generator.SegDataset(image_mean=cfg.image_mean,