delete device id for gpu

This commit is contained in:
yoonlee666 2020-10-09 10:54:53 +08:00
parent d7b7ba3797
commit 8fa83cca87
2 changed files with 15 additions and 2 deletions

View File

@ -62,7 +62,13 @@ def run_general_distill():
help="dataset type tfrecord/mindrecord, default is tfrecord")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
else:
raise Exception("Target error, GPU or Ascend is supported.")
context.set_context(reserve_class_name_in_scope=False)
context.set_context(variable_memory_max_size="30GB")

View File

@ -184,7 +184,14 @@ def run_task_distill(ckpt_file):
if ckpt_file == '':
raise ValueError("Student ckpt file should not be None")
cfg = phase2_cfg
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
else:
raise Exception("Target error, GPU or Ascend is supported.")
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
load_student_checkpoint_path = ckpt_file
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,