!16574 [GraphKernel] bert large enable graph kernel on GPU.

From: @chenlei_autodiff
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
This commit is contained in:
mindspore-ci-bot 2021-05-20 11:19:24 +08:00 committed by Gitee
commit 5b673684e3
1 changed files with 7 additions and 4 deletions

View File

@ -129,21 +129,24 @@ def _get_optimizer(args_opt, network):
def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
"""Judge whether is suitable to enable graph kernel."""
return graph_kernel_mode in ("auto", "true") and device_target == 'GPU' and \
cfg.bert_network == 'base' and cfg.optimizer == 'AdamWeightDecay'
cfg.bert_network in ('base', 'large') and cfg.optimizer == 'AdamWeightDecay'
def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel):
if enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
if device_target == 'GPU':
context.set_context(enable_graph_kernel=True,
graph_kernel_flags="--enable_stitch_fusion=false --enable_parallel_fusion=true")
if cfg.bert_network == 'base':
context.set_context(enable_graph_kernel=True,
graph_kernel_flags="--enable_stitch_fusion=false --enable_parallel_fusion=true")
else:
context.set_context(enable_graph_kernel=True)
else:
logger.warning('Graph kernel only supports GPU back-end now, run with graph kernel off.')
def _check_compute_type(args_opt, is_auto_enable_graph_kernel):
if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and \
not is_auto_enable_graph_kernel:
not is_auto_enable_graph_kernel and cfg.bert_network != 'base':
warning_message = 'Gpu only support fp32 temporarily, run with fp32.'
bert_net_cfg.compute_type = mstype.float32
if args_opt.enable_lossscale == "true":