[Bert][Gpu] Sync bert script changes from r1.1 to master
parent 01a0cdf5f0
commit 53d4510ea6
@@ -44,11 +44,18 @@ _current_dir = os.path.dirname(os.path.realpath(__file__))
 
 def _set_bert_all_reduce_split():
     """set bert all_reduce fusion split, support num_hidden_layers is 12 and 24."""
+    device_target = context.get_context('device_target')
+    enable_graph_kernel = context.get_context('enable_graph_kernel')
+    device_num = context.get_auto_parallel_context('device_num')
     if bert_net_cfg.num_hidden_layers == 12:
         if bert_net_cfg.use_relative_positions:
             context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 87, 116, 145, 174, 203, 217])
         else:
             context.set_auto_parallel_context(all_reduce_fusion_config=[28, 55, 82, 109, 136, 163, 190, 205])
+            if device_target == 'GPU' and enable_graph_kernel and device_num == 8:
+                context.set_auto_parallel_context(all_reduce_fusion_config=[180, 205])
+            elif device_target == 'GPU' and enable_graph_kernel and device_num == 16:
+                context.set_auto_parallel_context(all_reduce_fusion_config=[120, 205])
     elif bert_net_cfg.num_hidden_layers == 24:
         if bert_net_cfg.use_relative_positions:
             context.set_auto_parallel_context(all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421])
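Read together, the added lines make the 12-layer (base) split GPU-aware: inside the non-relative-positions branch, when graph kernel is enabled the default eight fusion groups are replaced by a two-way split whose boundary depends on how many devices take part. A consolidated sketch of the branch as it reads after this hunk (a fragment of the same script, relying on its existing context and bert_net_cfg imports):

    def _set_bert_all_reduce_split():
        """set bert all_reduce fusion split, support num_hidden_layers is 12 and 24."""
        device_target = context.get_context('device_target')
        enable_graph_kernel = context.get_context('enable_graph_kernel')
        device_num = context.get_auto_parallel_context('device_num')
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 87, 116, 145, 174, 203, 217])
            else:
                context.set_auto_parallel_context(all_reduce_fusion_config=[28, 55, 82, 109, 136, 163, 190, 205])
                # GPU + graph kernel: replace the eight-way split with a two-way split
                # whose boundary depends on the number of participating devices.
                if device_target == 'GPU' and enable_graph_kernel and device_num == 8:
                    context.set_auto_parallel_context(all_reduce_fusion_config=[180, 205])
                elif device_target == 'GPU' and enable_graph_kernel and device_num == 16:
                    context.set_auto_parallel_context(all_reduce_fusion_config=[120, 205])
        # 24-layer handling continues unchanged below.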
@@ -99,8 +106,7 @@ def _get_optimizer(args_opt, network):
 def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
     """Judge whether is suitable to enable graph kernel."""
     return graph_kernel_mode in ("auto", "true") and device_target == 'GPU' and \
-        cfg.bert_network == 'base' and (cfg.batch_size == 32 or cfg.batch_size == 64) and \
-        cfg.optimizer == 'AdamWeightDecay'
+        cfg.bert_network == 'base' and cfg.optimizer == 'AdamWeightDecay'
 
 
 def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel):
@@ -111,10 +117,15 @@ def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable
             logger.warning('Graph kernel only supports GPU back-end now, run with graph kernel off.')
 
 
-def _check_compute_type(device_target, is_auto_enable_graph_kernel):
-    if device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and not is_auto_enable_graph_kernel:
-        logger.warning('Gpu only support fp32 temporarily, run with fp32.')
+def _check_compute_type(args_opt, is_auto_enable_graph_kernel):
+    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and \
+        not is_auto_enable_graph_kernel:
+        warning_message = 'Gpu only support fp32 temporarily, run with fp32.'
         bert_net_cfg.compute_type = mstype.float32
+        if args_opt.enable_lossscale == "true":
+            args_opt.enable_lossscale = "false"
+            warning_message = 'Gpu only support fp32 temporarily, run with fp32 and disable lossscale.'
+        logger.warning(warning_message)
 
 
 def argparse_init():
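_check_compute_type now receives the parsed arguments rather than only the device target, so the same fp32 fallback can also switch off loss scaling. Consolidated, the new function reads as below (again a fragment of the script, using its existing bert_net_cfg, mstype and logger):

    def _check_compute_type(args_opt, is_auto_enable_graph_kernel):
        if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and \
            not is_auto_enable_graph_kernel:
            warning_message = 'Gpu only support fp32 temporarily, run with fp32.'
            bert_net_cfg.compute_type = mstype.float32
            # With compute forced to fp32, loss scaling is unnecessary, so the flag is
            # reset and the warning is extended accordingly.
            if args_opt.enable_lossscale == "true":
                args_opt.enable_lossscale = "false"
                warning_message = 'Gpu only support fp32 temporarily, run with fp32 and disable lossscale.'
            logger.warning(warning_message)

Building the message first and logging it once keeps the fallback to a single warning, whether or not loss scaling had been requested.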
@@ -160,6 +171,8 @@ def run_pretrain():
     args_opt = parser.parse_args()
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
     context.set_context(reserve_class_name_in_scope=False)
+    is_auto_enable_graph_kernel = _auto_enable_graph_kernel(args_opt.device_target, args_opt.enable_graph_kernel)
+    _set_graph_kernel_context(args_opt.device_target, args_opt.enable_graph_kernel, is_auto_enable_graph_kernel)
     ckpt_save_dir = args_opt.save_checkpoint_path
     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':
@@ -175,15 +188,12 @@ def run_pretrain():
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                           device_num=device_num)
-        if args_opt.device_target == 'Ascend':
-            _set_bert_all_reduce_split()
+        _set_bert_all_reduce_split()
     else:
         rank = 0
         device_num = 1
 
-    is_auto_enable_graph_kernel = _auto_enable_graph_kernel(args_opt.device_target, args_opt.enable_graph_kernel)
-    _set_graph_kernel_context(args_opt.device_target, args_opt.enable_graph_kernel, is_auto_enable_graph_kernel)
-    _check_compute_type(args_opt.device_target, is_auto_enable_graph_kernel)
+    _check_compute_type(args_opt, is_auto_enable_graph_kernel)
 
     if args_opt.accumulation_steps > 1:
         logger.info("accumulation steps: {}".format(args_opt.accumulation_steps))
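Taken together with the earlier hunks, run_pretrain() now decides on graph kernel right after the global context is set up, applies the BERT all-reduce split on GPU as well as Ascend, and hands the whole args_opt to _check_compute_type so the loss-scale flag can be adjusted in the same place. An abridged reconstruction of the resulting flow (the communication-init lines inside the distribute branch are unchanged and not part of these hunks, so they are elided):

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    # Graph kernel is configured before any parallel setup, so _set_bert_all_reduce_split()
    # can read the final 'enable_graph_kernel' value back from the context.
    is_auto_enable_graph_kernel = _auto_enable_graph_kernel(args_opt.device_target, args_opt.enable_graph_kernel)
    _set_graph_kernel_context(args_opt.device_target, args_opt.enable_graph_kernel, is_auto_enable_graph_kernel)
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        ...  # unchanged: initialise communication and derive rank / device_num
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          device_num=device_num)
        _set_bert_all_reduce_split()          # no longer guarded by device_target == 'Ascend'
    else:
        rank = 0
        device_num = 1
    _check_compute_type(args_opt, is_auto_enable_graph_kernel)   # replaces the three old calls here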
@@ -32,7 +32,7 @@ mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-st
     --distribute="true" \
     --epoch_size=$EPOCH_SIZE \
     --enable_save_ckpt="true" \
-    --enable_lossscale="false" \
+    --enable_lossscale="true" \
     --do_shuffle="true" \
     --enable_data_sink="true" \
     --data_sink_steps=20 \
@@ -36,7 +36,7 @@ python run_pretrain.py \
     --distribute="false" \
     --epoch_size=$EPOCH_SIZE \
     --enable_save_ckpt="true" \
-    --enable_lossscale="false" \
+    --enable_lossscale="true" \
     --do_shuffle="true" \
     --enable_data_sink="true" \
     --data_sink_steps=20 \
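Both GPU launch scripts, the distributed one above and the standalone one here, now pass --enable_lossscale="true" by default. That is consistent with the reworked _check_compute_type: when the fp32 fallback triggers, the script turns the flag back to "false" itself; otherwise loss scaling stays available for reduced-precision training. How the flag is ultimately consumed lies outside these hunks; the sketch below only illustrates the stock MindSpore pattern such a flag typically gates, with a hypothetical helper name and constants rather than this repository's actual wrapper code:

    import mindspore.nn as nn

    def build_train_cell(net_with_loss, optimizer, args_opt):
        # Hypothetical helper: wrap the network in a dynamic loss-scale cell only
        # when the launch script passed --enable_lossscale="true".
        if args_opt.enable_lossscale == "true":
            update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 16,
                                                        scale_factor=2,
                                                        scale_window=1000)
            return nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
                                                    scale_sense=update_cell)
        return nn.TrainOneStepCell(net_with_loss, optimizer)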