forked from mindspore-Ecosystem/mindspore
fix fcns2
This commit is contained in:
parent
dafd2713ac
commit
d8bf9ab8da
|
@ -49,12 +49,14 @@ def train():
|
|||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
||||
device_target="Ascend", device_id=args.device_id)
|
||||
# init multicards training
|
||||
args.rank = 0
|
||||
args.group_size = 1
|
||||
if device_num > 1:
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num)
|
||||
init()
|
||||
args.rank = get_rank()
|
||||
args.group_size = get_group_size()
|
||||
args.rank = get_rank()
|
||||
args.group_size = get_group_size()
|
||||
|
||||
# dataset
|
||||
dataset = data_generator.SegDataset(image_mean=cfg.image_mean,
|
||||
|
|
Loading…
Reference in New Issue