diff --git a/model_zoo/official/cv/vgg16/train.py b/model_zoo/official/cv/vgg16/train.py index 29bc344e574..57cc9da59cc 100644 --- a/model_zoo/official/cv/vgg16/train.py +++ b/model_zoo/official/cv/vgg16/train.py @@ -126,6 +126,7 @@ def merge_args(args_opt, cloud_args): if __name__ == '__main__': args = parse_args() + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) device_num = int(os.environ.get("DEVICE_NUM", 1)) if args.is_distributed: if args.device_target == "Ascend": @@ -143,7 +144,6 @@ if __name__ == '__main__': else: if args.device_target == "Ascend": context.set_context(device_id=args.device_id) - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) # select for master rank save ckpt or all rank save, compatible for model parallel args.rank_save_ckpt_flag = 0