forked from mindspore-Ecosystem/mindspore
!9454 fix hccl init in resnext50
From: @zhouyaqiang0 Reviewed-by: @c_34,@guoqi1024 Signed-off-by: @c_34
This commit is contained in:
commit
a6766001ec
|
@ -147,6 +147,8 @@ def parse_args(cloud_args=None):
|
|||
args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
|
||||
args.image_size = list(map(int, args.image_size.split(',')))
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||
device_target=args.platform, save_graphs=False)
|
||||
# init distributed
|
||||
if args.is_distributed:
|
||||
init()
|
||||
|
@ -190,8 +192,6 @@ def merge_args(args, cloud_args):
|
|||
def train(cloud_args=None):
|
||||
"""training process"""
|
||||
args = parse_args(cloud_args)
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||
device_target=args.platform, save_graphs=False)
|
||||
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
||||
|
||||
|
|
Loading…
Reference in New Issue