forked from mindspore-Ecosystem/mindspore
!16574 [GraphKernel] bert large enable graph kernel on GPU.
From: @chenlei_autodiff Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou
This commit is contained in:

commit 5b673684e3
@@ -129,21 +129,24 @@ def _get_optimizer(args_opt, network):
 def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
     """Judge whether is suitable to enable graph kernel."""
     return graph_kernel_mode in ("auto", "true") and device_target == 'GPU' and \
-        cfg.bert_network == 'base' and cfg.optimizer == 'AdamWeightDecay'
+        cfg.bert_network in ('base', 'large') and cfg.optimizer == 'AdamWeightDecay'


 def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel):
     if enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
         if device_target == 'GPU':
-            context.set_context(enable_graph_kernel=True,
-                                graph_kernel_flags="--enable_stitch_fusion=false --enable_parallel_fusion=true")
+            if cfg.bert_network == 'base':
+                context.set_context(enable_graph_kernel=True,
+                                    graph_kernel_flags="--enable_stitch_fusion=false --enable_parallel_fusion=true")
+            else:
+                context.set_context(enable_graph_kernel=True)
         else:
             logger.warning('Graph kernel only supports GPU back-end now, run with graph kernel off.')


 def _check_compute_type(args_opt, is_auto_enable_graph_kernel):
     if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and \
-        not is_auto_enable_graph_kernel:
+        not is_auto_enable_graph_kernel and cfg.bert_network != 'base':
         warning_message = 'Gpu only support fp32 temporarily, run with fp32.'
         bert_net_cfg.compute_type = mstype.float32
         if args_opt.enable_lossscale == "true":
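For reference, below is a minimal sketch of the graph kernel decision logic after this patch, runnable without MindSpore or a GPU. SimpleNamespace stands in for the script's cfg object, and instead of calling context.set_context the helper returns the keyword arguments it would pass, so the base/large branching can be inspected in isolation. The names auto_enable_graph_kernel, plan_graph_kernel_context and GK_FLAGS are illustrative stand-ins, not identifiers from the repository.

# Sketch (not repository code): post-patch graph kernel selection, isolated
# from MindSpore so the branching can be checked on any machine.
from types import SimpleNamespace

GK_FLAGS = "--enable_stitch_fusion=false --enable_parallel_fusion=true"


def auto_enable_graph_kernel(device_target, graph_kernel_mode, cfg):
    """Patched rule: both 'base' and 'large' qualify for auto enabling."""
    return (graph_kernel_mode in ("auto", "true")
            and device_target == 'GPU'
            and cfg.bert_network in ('base', 'large')
            and cfg.optimizer == 'AdamWeightDecay')


def plan_graph_kernel_context(device_target, enable_graph_kernel, auto_enabled, cfg):
    """Return the kwargs the patched helper would pass to context.set_context."""
    if enable_graph_kernel == "true" or auto_enabled:
        if device_target == 'GPU':
            if cfg.bert_network == 'base':
                # bert base keeps the explicit stitch/parallel fusion flags
                return {"enable_graph_kernel": True, "graph_kernel_flags": GK_FLAGS}
            # bert large enables graph kernel with default flags only
            return {"enable_graph_kernel": True}
        print('Graph kernel only supports GPU back-end now, run with graph kernel off.')
    return None


if __name__ == "__main__":
    for network in ("base", "large"):
        cfg = SimpleNamespace(bert_network=network, optimizer='AdamWeightDecay')
        auto = auto_enable_graph_kernel('GPU', 'auto', cfg)
        print(network, "->", plan_graph_kernel_context('GPU', 'auto', auto, cfg))

Running the sketch shows 'base' keeping the explicit fusion flags while 'large' enables graph kernel with default flags, which is the behavioural change introduced by the diff above.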