forked from mindspore-Ecosystem/mindspore
!7172 bugfix tinybert
Merge pull request !7172 from yoonlee666/tinybertbugfix
This commit is contained in:
commit
b1ac85bde9
|
@ -112,7 +112,12 @@ def run_predistill():
|
|||
run predistill
|
||||
"""
|
||||
cfg = phase1_cfg
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
if args_opt.device_target == "Ascend":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
elif args_opt.device_target == "GPU":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
else:
|
||||
raise Exception("Target error, GPU or Ascend is supported.")
|
||||
context.set_context(reserve_class_name_in_scope=False)
|
||||
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
|
||||
load_student_checkpoint_path = args_opt.load_gd_ckpt_path
|
||||
|
@ -265,7 +270,12 @@ def do_eval_standalone():
|
|||
ckpt_file = args_opt.load_td1_ckpt_path
|
||||
if ckpt_file == '':
|
||||
raise ValueError("Student ckpt file should not be None")
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
if args_opt.device_target == "Ascend":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
elif args_opt.device_target == "GPU":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
else:
|
||||
raise Exception("Target error, GPU or Ascend is supported.")
|
||||
eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
|
||||
param_dict = load_checkpoint(ckpt_file)
|
||||
new_param_dict = {}
|
||||
|
|
Loading…
Reference in New Issue