diff --git a/model_zoo/official/nlp/bert/pretrain_config.yaml b/model_zoo/official/nlp/bert/pretrain_config.yaml
index f2b4b335f8b..14492574a6f 100644
--- a/model_zoo/official/nlp/bert/pretrain_config.yaml
+++ b/model_zoo/official/nlp/bert/pretrain_config.yaml
@@ -35,6 +35,7 @@ schema_dir: ''
 # ==============================================================================
 # pretrain related
 batch_size: 32
+# Available: [base, nezha, large, large_acc]
 bert_network: 'base'
 loss_scale_value: 65536
 scale_factor: 2
@@ -121,6 +122,24 @@ nezha_net_cfg:
 # large
 large_batch_size: 24
 large_net_cfg:
+    seq_length: 512
+    vocab_size: 30522
+    hidden_size: 1024
+    num_hidden_layers: 24
+    num_attention_heads: 16
+    intermediate_size: 4096
+    hidden_act: "gelu"
+    hidden_dropout_prob: 0.1
+    attention_probs_dropout_prob: 0.1
+    max_position_embeddings: 512
+    type_vocab_size: 2
+    initializer_range: 0.02
+    use_relative_positions: False
+    dtype: mstype.float32
+    compute_type: mstype.float16
+# Accelerated large network, currently only supported on Ascend.
+large_acc_batch_size: 24
+large_acc_net_cfg:
     seq_length: 512
     vocab_size: 30522
     hidden_size: 1024
@@ -137,6 +156,7 @@ large_net_cfg:
     dtype: mstype.float32
     compute_type: mstype.float16
+
 
 ---
 # Help description for each configuration
 enable_modelarts: "Whether training on modelarts, default: False"
diff --git a/model_zoo/official/nlp/bert/src/model_utils/config.py b/model_zoo/official/nlp/bert/src/model_utils/config.py
index 7c653c17b9e..49680622182 100644
--- a/model_zoo/official/nlp/bert/src/model_utils/config.py
+++ b/model_zoo/official/nlp/bert/src/model_utils/config.py
@@ -140,6 +140,9 @@ def extra_operations(cfg):
     elif cfg.bert_network == 'large':
         cfg.batch_size = cfg.large_batch_size
         _bert_net_cfg = cfg.large_net_cfg
+    elif cfg.bert_network == 'large_acc':
+        cfg.batch_size = cfg.large_acc_batch_size
+        _bert_net_cfg = cfg.large_acc_net_cfg
     else:
         pass
     cfg.bert_net_cfg = BertConfig(**_bert_net_cfg.__dict__)
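Note that the `large_acc_net_cfg` block added here duplicates the hyperparameters of `large_net_cfg`, so the acceleration presumably comes from the Ascend training path rather than a different model shape. For context, below is a minimal, self-contained sketch of the dispatch that `extra_operations()` performs after this change: the `bert_network` string selects the matching per-network batch size and net config. The `select_net_cfg` helper and the trimmed config dict are illustrative assumptions, not code from the repo, and only the two branches relevant to this diff are shown.

# Illustrative stand-in for the parsed pretrain_config.yaml;
# only keys visible in the diff above are assumed to exist.
cfg = {
    "bert_network": "large_acc",
    "large_batch_size": 24,
    "large_net_cfg": {"seq_length": 512, "hidden_size": 1024},
    "large_acc_batch_size": 24,
    "large_acc_net_cfg": {"seq_length": 512, "hidden_size": 1024},
}

def select_net_cfg(cfg):
    """Return (batch_size, net_cfg) for the selected bert_network,
    mirroring the if/elif chain in extra_operations()."""
    network = cfg["bert_network"]
    if network == "large":
        return cfg["large_batch_size"], cfg["large_net_cfg"]
    if network == "large_acc":  # the branch added by this diff
        return cfg["large_acc_batch_size"], cfg["large_acc_net_cfg"]
    raise ValueError(f"unhandled bert_network: {network}")

batch_size, net_cfg = select_net_cfg(cfg)
print(batch_size, net_cfg["seq_length"])  # -> 24 512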