forked from mindspore-Ecosystem/mindspore
fix GPU device_id bug
This commit is contained in:
parent
76bd0f1245
commit
6b858480c8
|
@ -38,7 +38,9 @@ parser.add_argument("--file_name", type=str, default="alexnet", help="output fil
|
||||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||||
args_opt = parser.parse_args()
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||||
|
if args_opt.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args_opt.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if args_opt.dataset_name == 'cifar10':
|
if args_opt.dataset_name == 'cifar10':
|
||||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
config = ConfigCenterface()
|
config = ConfigCenterface()
|
||||||
|
|
|
@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
parser.add_argument("--ckpt_file", type=str, default="./ckpts/cnn_ctc.ckpt", help="CNN&CTC ckpt file.")
|
parser.add_argument("--ckpt_file", type=str, default="./ckpts/cnn_ctc.ckpt", help="CNN&CTC ckpt file.")
|
||||||
args_opt = parser.parse_args()
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||||
|
if args_opt.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args_opt.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cfg = Config_CNNCTC()
|
cfg = Config_CNNCTC()
|
||||||
|
|
|
@ -34,7 +34,9 @@ parser.add_argument('--model', type=str.lower, default='deeplab_v3_s8', choices=
|
||||||
parser.add_argument('--num_classes', type=int, default=21, help='the number of classes (Default: 21)')
|
parser.add_argument('--num_classes', type=int, default=21, help='the number of classes (Default: 21)')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if args.model == 'deeplab_v3_s16':
|
if args.model == 'deeplab_v3_s16':
|
||||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
network = DenseNet121(config.num_classes)
|
network = DenseNet121(config.num_classes)
|
||||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if args.device_target != "GPU":
|
if args.device_target != "GPU":
|
||||||
|
|
|
@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
parser.add_argument('--ckpt_file', type=str, default='', help='fasterrcnn ckpt file.')
|
parser.add_argument('--ckpt_file', type=str, default='', help='fasterrcnn ckpt file.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
net = FasterRcnn_Infer(config=config)
|
net = FasterRcnn_Infer(config=config)
|
||||||
|
|
|
@ -37,7 +37,9 @@ parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['ima
|
||||||
help='dataset name.')
|
help='dataset name.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if args.dataset_name == 'cifar10':
|
if args.dataset_name == 'cifar10':
|
||||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
net = InceptionV3(num_classes=cfg.num_classes, is_training=False)
|
net = InceptionV3(num_classes=cfg.num_classes, is_training=False)
|
||||||
|
|
|
@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
net = Inceptionv4(classes=config.num_classes)
|
net = Inceptionv4(classes=config.num_classes)
|
||||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,9 @@ parser.add_argument('--device_target', type=str, default="Ascend",
|
||||||
choices=['Ascend', 'GPU', 'CPU'], help='device target (default: Ascend)')
|
choices=['Ascend', 'GPU', 'CPU'], help='device target (default: Ascend)')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
net = MaskRcnn_Infer(config=config)
|
net = MaskRcnn_Infer(config=config)
|
||||||
|
|
|
@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
config.test_batch_size = args.batch_size
|
config.test_batch_size = args.batch_size
|
||||||
|
|
|
@ -34,7 +34,9 @@ args = parser.parse_args()
|
||||||
args.is_training = False
|
args.is_training = False
|
||||||
args.run_distribute = False
|
args.run_distribute = False
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
|
||||||
|
if args.platform == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
cfg = set_config(args)
|
cfg = set_config(args)
|
||||||
|
|
|
@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
net = NASNetAMobile(num_classes=cfg.num_classes, is_training=False)
|
net = NASNetAMobile(num_classes=cfg.num_classes, is_training=False)
|
||||||
|
|
|
@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
||||||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
net = ETSNet(config)
|
net = ETSNet(config)
|
||||||
|
|
|
@ -38,7 +38,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
||||||
choices=["Ascend", "GPU", "CPU"], help="device target(default: Ascend)")
|
choices=["Ascend", "GPU", "CPU"], help="device target(default: Ascend)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
||||||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
net = get_network(num_classes=config.num_classes, platform=args.device_target)
|
net = get_network(num_classes=config.num_classes, platform=args.device_target)
|
||||||
|
|
|
@ -37,7 +37,9 @@ parser.add_argument('--model_size', type=str, default='2.0x', choices=['2.0x', '
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, default="GPU",
|
||||||
choices=["Ascend", "GPU", "CPU"], help="device where the code will be implemented (default: GPU)")
|
choices=["Ascend", "GPU", "CPU"], help="device where the code will be implemented (default: GPU)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if args.device_target != 'GPU':
|
if args.device_target != 'GPU':
|
||||||
|
|
|
@ -45,7 +45,9 @@ if args.dataset == "cifar10":
|
||||||
else:
|
else:
|
||||||
num_classes = 1000
|
num_classes = 1000
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
net = squeezenet(num_classes=num_classes)
|
net = squeezenet(num_classes=num_classes)
|
||||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if config.model == "ssd300":
|
if config.model == "ssd300":
|
||||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
net = UNet(n_channels=cfg["num_channels"], n_classes=cfg["num_classes"])
|
net = UNet(n_channels=cfg["num_channels"], n_classes=cfg["num_classes"])
|
||||||
|
|
|
@ -46,7 +46,9 @@ args.batch_norm = cfg.batch_norm
|
||||||
args.has_dropout = cfg.has_dropout
|
args.has_dropout = cfg.has_dropout
|
||||||
args.image_size = list(map(int, cfg.image_size.split(',')))
|
args.image_size = list(map(int, cfg.image_size.split(',')))
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if args.dataset == "cifar10":
|
if args.dataset == "cifar10":
|
||||||
|
|
|
@ -30,7 +30,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
captcha_width = config.captcha_width
|
captcha_width = config.captcha_width
|
||||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# define net
|
# define net
|
||||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
network = YOLOV3DarkNet53(is_training=False)
|
network = YOLOV3DarkNet53(is_training=False)
|
||||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
config = ConfigYOLOV3ResNet18()
|
config = ConfigYOLOV3ResNet18()
|
||||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ts_shape = args.testing_shape
|
ts_shape = args.testing_shape
|
||||||
|
|
|
@ -36,7 +36,9 @@ parser.add_argument("--gnew_neighs", type=int, default=20, help="num of sampling
|
||||||
parser.add_argument("--activation", type=str, default="tanh", choices=["relu", "tanh"], help="activation function")
|
parser.add_argument("--activation", type=str, default="tanh", choices=["relu", "tanh"], help="activation function")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
num_user, num_item = 7068, 3570
|
num_user, num_item = 7068, 3570
|
||||||
|
|
|
@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
||||||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
||||||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
config = ConfigGCN()
|
config = ConfigGCN()
|
||||||
|
|
|
@ -39,7 +39,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
||||||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
label_list = []
|
label_list = []
|
||||||
with open(args.label_file_path) as f:
|
with open(args.label_file_path) as f:
|
||||||
|
|
|
@ -36,7 +36,9 @@ parser.add_argument('--gigaword_infer_config', type=str, required=True, help='gi
|
||||||
parser.add_argument('--vocab_file', type=str, required=True, help='vocabulary file')
|
parser.add_argument('--vocab_file', type=str, required=True, help='vocabulary file')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
def get_config(config_file):
|
def get_config(config_file):
|
||||||
tfm_config = TransformerConfig.from_json_file(config_file)
|
tfm_config = TransformerConfig.from_json_file(config_file)
|
||||||
|
|
|
@ -37,7 +37,9 @@ parser.add_argument('--dataset', type=str, default='MR', choices=['MR', 'SUBJ',
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
||||||
parser.add_argument('--task_name', type=str, default='SST-2', choices=['SST-2', 'QNLI', 'MNLI'], help='task name')
|
parser.add_argument('--task_name', type=str, default='SST-2', choices=['SST-2', 'QNLI', 'MNLI'], help='task name')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
DEFAULT_NUM_LABELS = 2
|
DEFAULT_NUM_LABELS = 2
|
||||||
DEFAULT_SEQ_LENGTH = 128
|
DEFAULT_SEQ_LENGTH = 128
|
||||||
|
|
|
@ -33,7 +33,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
||||||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tfm_model = TransformerModel(config=transformer_net_cfg, is_training=False, use_one_hot_embeddings=False)
|
tfm_model = TransformerModel(config=transformer_net_cfg, is_training=False, use_one_hot_embeddings=False)
|
||||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
data_config = DataConfig()
|
data_config = DataConfig()
|
||||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
||||||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
topk = rconst.TOP_K
|
topk = rconst.TOP_K
|
||||||
|
|
|
@ -32,7 +32,9 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
||||||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
widedeep_config = WideDeepConfig()
|
widedeep_config = WideDeepConfig()
|
||||||
|
|
Loading…
Reference in New Issue