forked from mindspore-Ecosystem/mindspore
fix fcns
This commit is contained in:
parent
1965ecb9a1
commit
b445dab0f6
|
@ -46,6 +46,8 @@ def train():
|
|||
args = parse_args()
|
||||
cfg = FCN8s_VOC2012_cfg
|
||||
device_num = int(os.environ.get("DEVICE_NUM", 1))
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
||||
device_target="Ascend", device_id=args.device_id)
|
||||
# init multicards training
|
||||
if device_num > 1:
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
|
@ -54,9 +56,6 @@ def train():
|
|||
args.rank = get_rank()
|
||||
args.group_size = get_group_size()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
||||
device_target="Ascend", device_id=args.device_id)
|
||||
|
||||
# dataset
|
||||
dataset = data_generator.SegDataset(image_mean=cfg.image_mean,
|
||||
image_std=cfg.image_std,
|
||||
|
|
Loading…
Reference in New Issue