diff --git a/model_zoo/official/cv/resnext50/train.py b/model_zoo/official/cv/resnext50/train.py index 27e782da9c1..7a7aaa61572 100644 --- a/model_zoo/official/cv/resnext50/train.py +++ b/model_zoo/official/cv/resnext50/train.py @@ -147,6 +147,8 @@ def parse_args(cloud_args=None): args.lr_epochs = list(map(int, args.lr_epochs.split(','))) args.image_size = list(map(int, args.image_size.split(','))) + context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, + device_target=args.platform, save_graphs=False) # init distributed if args.is_distributed: init() @@ -190,8 +192,6 @@ def merge_args(args, cloud_args): def train(cloud_args=None): """training process""" args = parse_args(cloud_args) - context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, - device_target=args.platform, save_graphs=False) if os.getenv('DEVICE_ID', "not_set").isdigit(): context.set_context(device_id=int(os.getenv('DEVICE_ID')))