From 0e85c5c9a4e45a5ca4da2a73ea27afcfbbc37b3f Mon Sep 17 00:00:00 2001 From: yoonlee666 Date: Sat, 10 Oct 2020 21:49:55 +0800 Subject: [PATCH] bugfix tinybert --- .../official/nlp/tinybert/run_task_distill.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/model_zoo/official/nlp/tinybert/run_task_distill.py b/model_zoo/official/nlp/tinybert/run_task_distill.py index 459fd529009..cd35bc5c346 100644 --- a/model_zoo/official/nlp/tinybert/run_task_distill.py +++ b/model_zoo/official/nlp/tinybert/run_task_distill.py @@ -112,7 +112,12 @@ def run_predistill(): run predistill """ cfg = phase1_cfg - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + if args_opt.device_target == "Ascend": + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + elif args_opt.device_target == "GPU": + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) + else: + raise Exception("Target error, GPU or Ascend is supported.") context.set_context(reserve_class_name_in_scope=False) load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path load_student_checkpoint_path = args_opt.load_gd_ckpt_path @@ -265,7 +270,12 @@ def do_eval_standalone(): ckpt_file = args_opt.load_td1_ckpt_path if ckpt_file == '': raise ValueError("Student ckpt file should not be None") - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + if args_opt.device_target == "Ascend": + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + elif args_opt.device_target == "GPU": + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) + else: + raise Exception("Target error, GPU or Ascend is supported.") eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student") param_dict = load_checkpoint(ckpt_file) new_param_dict = {}