Update BERT config, add large_acc config

This commit is contained in:
chenhaozhe 2021-07-13 20:42:58 +08:00
parent 7001ae7921
commit 76610b48c4
2 changed files with 23 additions and 0 deletions

View File

@ -35,6 +35,7 @@ schema_dir: ''
# ==============================================================================
# pretrain related
batch_size: 32
# Available: [base, nezha, large, large_acc]
bert_network: 'base'
loss_scale_value: 65536
scale_factor: 2
@ -121,6 +122,24 @@ nezha_net_cfg:
# large
large_batch_size: 24
large_net_cfg:
seq_length: 512
vocab_size: 30522
hidden_size: 1024
num_hidden_layers: 24
num_attention_heads: 16
intermediate_size: 4096
hidden_act: "gelu"
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
max_position_embeddings: 512
type_vocab_size: 2
initializer_range: 0.02
use_relative_positions: False
dtype: mstype.float32
compute_type: mstype.float16
# Accelerated large network, currently supported only on Ascend.
large_acc_batch_size: 24
large_acc_net_cfg:
seq_length: 512
vocab_size: 30522
hidden_size: 1024
@ -137,6 +156,7 @@ large_net_cfg:
dtype: mstype.float32
compute_type: mstype.float16
---
# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"

View File

@ -140,6 +140,9 @@ def extra_operations(cfg):
elif cfg.bert_network == 'large':
cfg.batch_size = cfg.large_batch_size
_bert_net_cfg = cfg.large_net_cfg
elif cfg.bert_network == 'large_acc':
cfg.batch_size = cfg.large_acc_batch_size
_bert_net_cfg = cfg.large_acc_net_cfg
else:
pass
cfg.bert_net_cfg = BertConfig(**_bert_net_cfg.__dict__)