diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py index 87564388c48..2eef5b0a503 100644 --- a/model_zoo/official/nlp/bert/run_pretrain.py +++ b/model_zoo/official/nlp/bert/run_pretrain.py @@ -129,21 +129,24 @@ def _get_optimizer(args_opt, network): def _auto_enable_graph_kernel(device_target, graph_kernel_mode): """Judge whether is suitable to enable graph kernel.""" return graph_kernel_mode in ("auto", "true") and device_target == 'GPU' and \ - cfg.bert_network == 'base' and cfg.optimizer == 'AdamWeightDecay' + cfg.bert_network in ('base', 'large') and cfg.optimizer == 'AdamWeightDecay' def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel): if enable_graph_kernel == "true" or is_auto_enable_graph_kernel: if device_target == 'GPU': - context.set_context(enable_graph_kernel=True, - graph_kernel_flags="--enable_stitch_fusion=false --enable_parallel_fusion=true") + if cfg.bert_network == 'base': + context.set_context(enable_graph_kernel=True, + graph_kernel_flags="--enable_stitch_fusion=false --enable_parallel_fusion=true") + else: + context.set_context(enable_graph_kernel=True) else: logger.warning('Graph kernel only supports GPU back-end now, run with graph kernel off.') def _check_compute_type(args_opt, is_auto_enable_graph_kernel): if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and \ - not is_auto_enable_graph_kernel: + not is_auto_enable_graph_kernel and cfg.bert_network != 'base': warning_message = 'Gpu only support fp32 temporarily, run with fp32.' bert_net_cfg.compute_type = mstype.float32 if args_opt.enable_lossscale == "true":