forked from mindspore-Ecosystem/mindspore
!17156 Enable Graph Kernel for TinyBert on GPU
From: @jiaoy1224 Reviewed-by: @ckey_dou,@gaoxiong1 Signed-off-by: @ckey_dou
This commit is contained in:
commit
ad1ea03779
|
@ -94,6 +94,7 @@ def run_general_distill():
|
|||
|
||||
enable_loss_scale = True
|
||||
if args_opt.device_target == "GPU":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
if bert_student_net_cfg.compute_type != mstype.float32:
|
||||
logger.warning('Compute about the student only support float32 temporarily, run with float32.')
|
||||
bert_student_net_cfg.compute_type = mstype.float32
|
||||
|
|
|
@ -339,6 +339,7 @@ if __name__ == '__main__':
|
|||
context.set_context(device_id=args_opt.device_id)
|
||||
enable_loss_scale = True
|
||||
if args_opt.device_target == "GPU":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
if td_student_net_cfg.compute_type != mstype.float32:
|
||||
logger.warning('Compute about the student only support float32 temporarily, run with float32.')
|
||||
td_student_net_cfg.compute_type = mstype.float32
|
||||
|
|
Loading…
Reference in New Issue