forked from mindspore-Ecosystem/mindspore
fix fcns2
This commit is contained in:
parent
dafd2713ac
commit
d8bf9ab8da
|
@ -49,12 +49,14 @@ def train():
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
||||||
device_target="Ascend", device_id=args.device_id)
|
device_target="Ascend", device_id=args.device_id)
|
||||||
# init multicards training
|
# init multicards training
|
||||||
|
args.rank = 0
|
||||||
|
args.group_size = 1
|
||||||
if device_num > 1:
|
if device_num > 1:
|
||||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num)
|
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num)
|
||||||
init()
|
init()
|
||||||
args.rank = get_rank()
|
args.rank = get_rank()
|
||||||
args.group_size = get_group_size()
|
args.group_size = get_group_size()
|
||||||
|
|
||||||
# dataset
|
# dataset
|
||||||
dataset = data_generator.SegDataset(image_mean=cfg.image_mean,
|
dataset = data_generator.SegDataset(image_mean=cfg.image_mean,
|
||||||
|
|
Loading…
Reference in New Issue