forked from mindspore-Ecosystem/mindspore
fix GPU device_id bug
This commit is contained in:
parent
76bd0f1245
commit
6b858480c8
|
@ -38,7 +38,9 @@ parser.add_argument("--file_name", type=str, default="alexnet", help="output fil
|
|||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
if args_opt.device_target == "Ascend":
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args_opt.dataset_name == 'cifar10':
|
||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = ConfigCenterface()
|
||||
|
|
|
@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
parser.add_argument("--ckpt_file", type=str, default="./ckpts/cnn_ctc.ckpt", help="CNN&CTC ckpt file.")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
if args_opt.device_target == "Ascend":
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = Config_CNNCTC()
|
||||
|
|
|
@ -34,7 +34,9 @@ parser.add_argument('--model', type=str.lower, default='deeplab_v3_s8', choices=
|
|||
parser.add_argument('--num_classes', type=int, default=21, help='the number of classes (Default: 21)')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args.model == 'deeplab_v3_s16':
|
||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
network = DenseNet121(config.num_classes)
|
||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.device_target != "GPU":
|
||||
|
|
|
@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
parser.add_argument('--ckpt_file', type=str, default='', help='fasterrcnn ckpt file.')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = FasterRcnn_Infer(config=config)
|
||||
|
|
|
@ -37,7 +37,9 @@ parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['ima
|
|||
help='dataset name.')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args.dataset_name == 'cifar10':
|
||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = InceptionV3(num_classes=cfg.num_classes, is_training=False)
|
||||
|
|
|
@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = Inceptionv4(classes=config.num_classes)
|
||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
|
|
@ -31,7 +31,9 @@ parser.add_argument('--device_target', type=str, default="Ascend",
|
|||
choices=['Ascend', 'GPU', 'CPU'], help='device target (default: Ascend)')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = MaskRcnn_Infer(config=config)
|
||||
|
|
|
@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
config.test_batch_size = args.batch_size
|
||||
|
|
|
@ -34,7 +34,9 @@ args = parser.parse_args()
|
|||
args.is_training = False
|
||||
args.run_distribute = False
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
|
||||
if args.platform == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
cfg = set_config(args)
|
||||
|
|
|
@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = NASNetAMobile(num_classes=cfg.num_classes, is_training=False)
|
||||
|
|
|
@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
|||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = ETSNet(config)
|
||||
|
|
|
@ -38,7 +38,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
|||
choices=["Ascend", "GPU", "CPU"], help="device target(default: Ascend)")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
|||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = get_network(num_classes=config.num_classes, platform=args.device_target)
|
||||
|
|
|
@ -37,7 +37,9 @@ parser.add_argument('--model_size', type=str, default='2.0x', choices=['2.0x', '
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
|
|
@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, default="GPU",
|
|||
choices=["Ascend", "GPU", "CPU"], help="device where the code will be implemented (default: GPU)")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args.device_target != 'GPU':
|
||||
|
|
|
@ -45,7 +45,9 @@ if args.dataset == "cifar10":
|
|||
else:
|
||||
num_classes = 1000
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = squeezenet(num_classes=num_classes)
|
||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if config.model == "ssd300":
|
||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
net = UNet(n_channels=cfg["num_channels"], n_classes=cfg["num_classes"])
|
||||
|
|
|
@ -46,7 +46,9 @@ args.batch_norm = cfg.batch_norm
|
|||
args.has_dropout = cfg.has_dropout
|
||||
args.image_size = list(map(int, cfg.image_size.split(',')))
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args.dataset == "cifar10":
|
||||
|
|
|
@ -30,7 +30,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
captcha_width = config.captcha_width
|
||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# define net
|
||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
network = YOLOV3DarkNet53(is_training=False)
|
||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = ConfigYOLOV3ResNet18()
|
||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
ts_shape = args.testing_shape
|
||||
|
|
|
@ -36,7 +36,9 @@ parser.add_argument("--gnew_neighs", type=int, default=20, help="num of sampling
|
|||
parser.add_argument("--activation", type=str, default="tanh", choices=["relu", "tanh"], help="activation function")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
num_user, num_item = 7068, 3570
|
||||
|
|
|
@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
|||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
|
|
@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
|||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = ConfigGCN()
|
||||
|
|
|
@ -39,7 +39,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
|||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
label_list = []
|
||||
with open(args.label_file_path) as f:
|
||||
|
|
|
@ -36,7 +36,9 @@ parser.add_argument('--gigaword_infer_config', type=str, required=True, help='gi
|
|||
parser.add_argument('--vocab_file', type=str, required=True, help='vocabulary file')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
def get_config(config_file):
|
||||
tfm_config = TransformerConfig.from_json_file(config_file)
|
||||
|
|
|
@ -37,7 +37,9 @@ parser.add_argument('--dataset', type=str, default='MR', choices=['MR', 'SUBJ',
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
|
|
@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
|||
parser.add_argument('--task_name', type=str, default='SST-2', choices=['SST-2', 'QNLI', 'MNLI'], help='task name')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
DEFAULT_NUM_LABELS = 2
|
||||
DEFAULT_SEQ_LENGTH = 128
|
||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
|||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
tfm_model = TransformerModel(config=transformer_net_cfg, is_training=False, use_one_hot_embeddings=False)
|
||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_config = DataConfig()
|
||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
|||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
topk = rconst.TOP_K
|
||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
|||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
widedeep_config = WideDeepConfig()
|
||||
|
|
Loading…
Reference in New Issue