forked from OSSInnovation/mindspore
delete device id for gpu
This commit is contained in:
parent
d7b7ba3797
commit
8fa83cca87
|
@ -62,7 +62,13 @@ def run_general_distill():
|
|||
help="dataset type tfrecord/mindrecord, default is tfrecord")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
if args_opt.device_target == "Ascend":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
elif args_opt.device_target == "GPU":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
else:
|
||||
raise Exception("Target error, GPU or Ascend is supported.")
|
||||
|
||||
context.set_context(reserve_class_name_in_scope=False)
|
||||
context.set_context(variable_memory_max_size="30GB")
|
||||
|
||||
|
|
|
@ -184,7 +184,14 @@ def run_task_distill(ckpt_file):
|
|||
if ckpt_file == '':
|
||||
raise ValueError("Student ckpt file should not be None")
|
||||
cfg = phase2_cfg
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
|
||||
if args_opt.device_target == "Ascend":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
elif args_opt.device_target == "GPU":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
else:
|
||||
raise Exception("Target error, GPU or Ascend is supported.")
|
||||
|
||||
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
|
||||
load_student_checkpoint_path = ckpt_file
|
||||
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
|
||||
|
|
Loading…
Reference in New Issue