!17156 Enable Graph Kernel for TinyBert on GPU

From: @jiaoy1224
Reviewed-by: @ckey_dou,@gaoxiong1
Signed-off-by: @ckey_dou
This commit is contained in:
mindspore-ci-bot 2021-05-29 09:35:22 +08:00 committed by Gitee
commit ad1ea03779
2 changed files with 2 additions and 0 deletions

View File

@ -94,6 +94,7 @@ def run_general_distill():
enable_loss_scale = True
if args_opt.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
if bert_student_net_cfg.compute_type != mstype.float32:
logger.warning('Compute about the student only support float32 temporarily, run with float32.')
bert_student_net_cfg.compute_type = mstype.float32

View File

@ -339,6 +339,7 @@ if __name__ == '__main__':
context.set_context(device_id=args_opt.device_id)
enable_loss_scale = True
if args_opt.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
if td_student_net_cfg.compute_type != mstype.float32:
logger.warning('Compute about the student only support float32 temporarily, run with float32.')
td_student_net_cfg.compute_type = mstype.float32