diff --git a/model_zoo/official/cv/alexnet/export.py b/model_zoo/official/cv/alexnet/export.py index f145abe91f3..3ad95a944d1 100644 --- a/model_zoo/official/cv/alexnet/export.py +++ b/model_zoo/official/cv/alexnet/export.py @@ -38,7 +38,9 @@ parser.add_argument("--file_name", type=str, default="alexnet", help="output fil parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format") args_opt = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) +if args_opt.device_target == "Ascend": + context.set_context(device_id=args_opt.device_id) if __name__ == '__main__': if args_opt.dataset_name == 'cifar10': diff --git a/model_zoo/official/cv/centerface/export.py b/model_zoo/official/cv/centerface/export.py index a54fdb83703..834c4411266 100644 --- a/model_zoo/official/cv/centerface/export.py +++ b/model_zoo/official/cv/centerface/export.py @@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': config = ConfigCenterface() diff --git a/model_zoo/official/cv/cnnctc/export.py b/model_zoo/official/cv/cnnctc/export.py index 30ed12a756c..02e33215440 100644 --- a/model_zoo/official/cv/cnnctc/export.py +++ b/model_zoo/official/cv/cnnctc/export.py @@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" parser.add_argument("--ckpt_file", type=str, default="./ckpts/cnn_ctc.ckpt", help="CNN&CTC ckpt file.") args_opt = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) +if args_opt.device_target == "Ascend": + context.set_context(device_id=args_opt.device_id) if __name__ == "__main__": cfg = Config_CNNCTC() diff --git a/model_zoo/official/cv/deeplabv3/export.py b/model_zoo/official/cv/deeplabv3/export.py index fe436023fa5..7bbbf772397 100644 --- a/model_zoo/official/cv/deeplabv3/export.py +++ b/model_zoo/official/cv/deeplabv3/export.py @@ -34,7 +34,9 @@ parser.add_argument('--model', type=str.lower, default='deeplab_v3_s8', choices= parser.add_argument('--num_classes', type=int, default=21, help='the number of classes (Default: 21)') args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': if args.model == 'deeplab_v3_s16': diff --git a/model_zoo/official/cv/densenet121/export.py b/model_zoo/official/cv/densenet121/export.py index ee95b7b18ae..d4211994b3a 100644 --- a/model_zoo/official/cv/densenet121/export.py +++ b/model_zoo/official/cv/densenet121/export.py @@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == "__main__": network = DenseNet121(config.num_classes) diff --git a/model_zoo/official/cv/efficientnet/export.py b/model_zoo/official/cv/efficientnet/export.py index f53dc2b7e6d..66692e28bd2 100644 --- a/model_zoo/official/cv/efficientnet/export.py +++ b/model_zoo/official/cv/efficientnet/export.py @@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == "__main__": if args.device_target != "GPU": diff --git a/model_zoo/official/cv/faster_rcnn/export.py b/model_zoo/official/cv/faster_rcnn/export.py index 79ce15099ee..3296d618c4e 100644 --- a/model_zoo/official/cv/faster_rcnn/export.py +++ b/model_zoo/official/cv/faster_rcnn/export.py @@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" parser.add_argument('--ckpt_file', type=str, default='', help='fasterrcnn ckpt file.') args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': net = FasterRcnn_Infer(config=config) diff --git a/model_zoo/official/cv/googlenet/export.py b/model_zoo/official/cv/googlenet/export.py index 28b260bd056..e701a406d3a 100644 --- a/model_zoo/official/cv/googlenet/export.py +++ b/model_zoo/official/cv/googlenet/export.py @@ -37,7 +37,9 @@ parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['ima help='dataset name.') args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': if args.dataset_name == 'cifar10': diff --git a/model_zoo/official/cv/inceptionv3/export.py b/model_zoo/official/cv/inceptionv3/export.py index 9fb06a35a9d..17845bea8b5 100644 --- a/model_zoo/official/cv/inceptionv3/export.py +++ b/model_zoo/official/cv/inceptionv3/export.py @@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': net = InceptionV3(num_classes=cfg.num_classes, is_training=False) diff --git a/model_zoo/official/cv/inceptionv4/export.py b/model_zoo/official/cv/inceptionv4/export.py index 8b51c824605..16dcd03f5f0 100644 --- a/model_zoo/official/cv/inceptionv4/export.py +++ b/model_zoo/official/cv/inceptionv4/export.py @@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': net = Inceptionv4(classes=config.num_classes) diff --git a/model_zoo/official/cv/lenet/export.py b/model_zoo/official/cv/lenet/export.py index efff69db6cf..440630d49f4 100644 --- a/model_zoo/official/cv/lenet/export.py +++ b/model_zoo/official/cv/lenet/export.py @@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == "__main__": diff --git a/model_zoo/official/cv/maskrcnn/export.py b/model_zoo/official/cv/maskrcnn/export.py index 091822f2d99..103e1925141 100644 --- a/model_zoo/official/cv/maskrcnn/export.py +++ b/model_zoo/official/cv/maskrcnn/export.py @@ -31,7 +31,9 @@ parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], help='device target (default: Ascend)') args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': net = MaskRcnn_Infer(config=config) diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/export.py b/model_zoo/official/cv/maskrcnn_mobilenetv1/export.py index 23ab45ee4ed..18a5edc8199 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/export.py +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/export.py @@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': config.test_batch_size = args.batch_size diff --git a/model_zoo/official/cv/mobilenetv2/export.py b/model_zoo/official/cv/mobilenetv2/export.py index 037a1294d0b..f1da6896729 100644 --- a/model_zoo/official/cv/mobilenetv2/export.py +++ b/model_zoo/official/cv/mobilenetv2/export.py @@ -34,7 +34,9 @@ args = parser.parse_args() args.is_training = False args.run_distribute = False -context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.platform) +if args.platform == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': cfg = set_config(args) diff --git a/model_zoo/official/cv/nasnet/export.py b/model_zoo/official/cv/nasnet/export.py index 9dcf00e70a2..cf460197592 100755 --- a/model_zoo/official/cv/nasnet/export.py +++ b/model_zoo/official/cv/nasnet/export.py @@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': net = NASNetAMobile(num_classes=cfg.num_classes, is_training=False) diff --git a/model_zoo/official/cv/psenet/export.py b/model_zoo/official/cv/psenet/export.py index 3adfa109210..5c159be835c 100755 --- a/model_zoo/official/cv/psenet/export.py +++ b/model_zoo/official/cv/psenet/export.py @@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': net = ETSNet(config) diff --git a/model_zoo/official/cv/resnet/export.py b/model_zoo/official/cv/resnet/export.py index 2c3fe625543..14157f6e076 100644 --- a/model_zoo/official/cv/resnet/export.py +++ b/model_zoo/official/cv/resnet/export.py @@ -38,7 +38,9 @@ parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"], help="device target(default: Ascend)") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': diff --git a/model_zoo/official/cv/resnext50/export.py b/model_zoo/official/cv/resnext50/export.py index e59aa5a876b..107a7fc0dcf 100644 --- a/model_zoo/official/cv/resnext50/export.py +++ b/model_zoo/official/cv/resnext50/export.py @@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': net = get_network(num_classes=config.num_classes, platform=args.device_target) diff --git a/model_zoo/official/cv/shufflenetv1/export.py b/model_zoo/official/cv/shufflenetv1/export.py index 93982b0fc22..6e62f637a1e 100644 --- a/model_zoo/official/cv/shufflenetv1/export.py +++ b/model_zoo/official/cv/shufflenetv1/export.py @@ -37,7 +37,9 @@ parser.add_argument('--model_size', type=str, default='2.0x', choices=['2.0x', ' args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': diff --git a/model_zoo/official/cv/shufflenetv2/export.py b/model_zoo/official/cv/shufflenetv2/export.py index 840ed610120..7df3ce0bfa0 100644 --- a/model_zoo/official/cv/shufflenetv2/export.py +++ b/model_zoo/official/cv/shufflenetv2/export.py @@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, default="GPU", choices=["Ascend", "GPU", "CPU"], help="device where the code will be implemented (default: GPU)") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': if args.device_target != 'GPU': diff --git a/model_zoo/official/cv/squeezenet/export.py b/model_zoo/official/cv/squeezenet/export.py index 6723b953cc5..896aadadb92 100755 --- a/model_zoo/official/cv/squeezenet/export.py +++ b/model_zoo/official/cv/squeezenet/export.py @@ -45,7 +45,9 @@ if args.dataset == "cifar10": else: num_classes = 1000 -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': net = squeezenet(num_classes=num_classes) diff --git a/model_zoo/official/cv/ssd/export.py b/model_zoo/official/cv/ssd/export.py index e6f1b2ae3ee..cd31c5adc95 100644 --- a/model_zoo/official/cv/ssd/export.py +++ b/model_zoo/official/cv/ssd/export.py @@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': if config.model == "ssd300": diff --git a/model_zoo/official/cv/unet/export.py b/model_zoo/official/cv/unet/export.py index 18f203624d2..fc72b65dd84 100644 --- a/model_zoo/official/cv/unet/export.py +++ b/model_zoo/official/cv/unet/export.py @@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == "__main__": net = UNet(n_channels=cfg["num_channels"], n_classes=cfg["num_classes"]) diff --git a/model_zoo/official/cv/vgg16/export.py b/model_zoo/official/cv/vgg16/export.py index 0d68beaab4d..de837684916 100644 --- a/model_zoo/official/cv/vgg16/export.py +++ b/model_zoo/official/cv/vgg16/export.py @@ -46,7 +46,9 @@ args.batch_norm = cfg.batch_norm args.has_dropout = cfg.has_dropout args.image_size = list(map(int, cfg.image_size.split(','))) -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': if args.dataset == "cifar10": diff --git a/model_zoo/official/cv/warpctc/export.py b/model_zoo/official/cv/warpctc/export.py index dd7dfe7babf..6f053da3b47 100644 --- a/model_zoo/official/cv/warpctc/export.py +++ b/model_zoo/official/cv/warpctc/export.py @@ -30,7 +30,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == "__main__": captcha_width = config.captcha_width diff --git a/model_zoo/official/cv/xception/export.py b/model_zoo/official/cv/xception/export.py index 0578fdc7629..ee57ce8a554 100644 --- a/model_zoo/official/cv/xception/export.py +++ b/model_zoo/official/cv/xception/export.py @@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == "__main__": # define net diff --git a/model_zoo/official/cv/yolov3_darknet53/export.py b/model_zoo/official/cv/yolov3_darknet53/export.py index 004d6505b86..8888ea3f3db 100644 --- a/model_zoo/official/cv/yolov3_darknet53/export.py +++ b/model_zoo/official/cv/yolov3_darknet53/export.py @@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == "__main__": network = YOLOV3DarkNet53(is_training=False) diff --git a/model_zoo/official/cv/yolov3_resnet18/export.py b/model_zoo/official/cv/yolov3_resnet18/export.py index dc4858be361..6f402a77b46 100644 --- a/model_zoo/official/cv/yolov3_resnet18/export.py +++ b/model_zoo/official/cv/yolov3_resnet18/export.py @@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == "__main__": config = ConfigYOLOV3ResNet18() diff --git a/model_zoo/official/cv/yolov4/export.py b/model_zoo/official/cv/yolov4/export.py index 6b68bd9bbd6..4743672d447 100644 --- a/model_zoo/official/cv/yolov4/export.py +++ b/model_zoo/official/cv/yolov4/export.py @@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == "__main__": ts_shape = args.testing_shape diff --git a/model_zoo/official/gnn/bgcf/export.py b/model_zoo/official/gnn/bgcf/export.py index 40309725ea3..4098919795e 100644 --- a/model_zoo/official/gnn/bgcf/export.py +++ b/model_zoo/official/gnn/bgcf/export.py @@ -36,7 +36,9 @@ parser.add_argument("--gnew_neighs", type=int, default=20, help="num of sampling parser.add_argument("--activation", type=str, default="tanh", choices=["relu", "tanh"], help="activation function") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == "__main__": num_user, num_item = 7068, 3570 diff --git a/model_zoo/official/gnn/gat/export.py b/model_zoo/official/gnn/gat/export.py index 0a7ce00bb90..d2b064539c8 100644 --- a/model_zoo/official/gnn/gat/export.py +++ b/model_zoo/official/gnn/gat/export.py @@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == "__main__": diff --git a/model_zoo/official/gnn/gcn/export.py b/model_zoo/official/gnn/gcn/export.py index 4d1a4296f46..aa91b622bdd 100644 --- a/model_zoo/official/gnn/gcn/export.py +++ b/model_zoo/official/gnn/gcn/export.py @@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == "__main__": config = ConfigGCN() diff --git a/model_zoo/official/nlp/bert/export.py b/model_zoo/official/nlp/bert/export.py index 5a34edd703a..ec2ba210b67 100644 --- a/model_zoo/official/nlp/bert/export.py +++ b/model_zoo/official/nlp/bert/export.py @@ -39,7 +39,9 @@ parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) label_list = [] with open(args.label_file_path) as f: diff --git a/model_zoo/official/nlp/mass/export.py b/model_zoo/official/nlp/mass/export.py index fb92fc594ed..2aa72487b16 100644 --- a/model_zoo/official/nlp/mass/export.py +++ b/model_zoo/official/nlp/mass/export.py @@ -36,7 +36,9 @@ parser.add_argument('--gigaword_infer_config', type=str, required=True, help='gi parser.add_argument('--vocab_file', type=str, required=True, help='vocabulary file') args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) def get_config(config_file): tfm_config = TransformerConfig.from_json_file(config_file) diff --git a/model_zoo/official/nlp/textcnn/export.py b/model_zoo/official/nlp/textcnn/export.py index fa08c9ca714..6404d2f481e 100644 --- a/model_zoo/official/nlp/textcnn/export.py +++ b/model_zoo/official/nlp/textcnn/export.py @@ -37,7 +37,9 @@ parser.add_argument('--dataset', type=str, default='MR', choices=['MR', 'SUBJ', args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': diff --git a/model_zoo/official/nlp/tinybert/export.py b/model_zoo/official/nlp/tinybert/export.py index ba1d914c89c..50aeeca37bd 100644 --- a/model_zoo/official/nlp/tinybert/export.py +++ b/model_zoo/official/nlp/tinybert/export.py @@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, default="Ascend", parser.add_argument('--task_name', type=str, default='SST-2', choices=['SST-2', 'QNLI', 'MNLI'], help='task name') args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) DEFAULT_NUM_LABELS = 2 DEFAULT_SEQ_LENGTH = 128 diff --git a/model_zoo/official/nlp/transformer/export.py b/model_zoo/official/nlp/transformer/export.py index 6cc998f7d62..40bd3d50944 100644 --- a/model_zoo/official/nlp/transformer/export.py +++ b/model_zoo/official/nlp/transformer/export.py @@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': tfm_model = TransformerModel(config=transformer_net_cfg, is_training=False, use_one_hot_embeddings=False) diff --git a/model_zoo/official/recommend/deepfm/export.py b/model_zoo/official/recommend/deepfm/export.py index 230bafa75b3..c981020a94c 100644 --- a/model_zoo/official/recommend/deepfm/export.py +++ b/model_zoo/official/recommend/deepfm/export.py @@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" help="device target") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == "__main__": data_config = DataConfig() diff --git a/model_zoo/official/recommend/ncf/export.py b/model_zoo/official/recommend/ncf/export.py index 10f85039692..2e7b4f289d5 100644 --- a/model_zoo/official/recommend/ncf/export.py +++ b/model_zoo/official/recommend/ncf/export.py @@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == "__main__": topk = rconst.TOP_K diff --git a/model_zoo/official/recommend/wide_and_deep/export.py b/model_zoo/official/recommend/wide_and_deep/export.py index 58da4cb8899..4ca956ef62f 100644 --- a/model_zoo/official/recommend/wide_and_deep/export.py +++ b/model_zoo/official/recommend/wide_and_deep/export.py @@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) +if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) if __name__ == '__main__': widedeep_config = WideDeepConfig()