!19840 adjust super params for bert thor

Merge pull request !19840 from wangshuangling/master
This commit is contained in:
i-robot 2021-07-10 06:24:57 +00:00 committed by Gitee
commit edd75a264b
4 changed files with 16 additions and 17 deletions

View File

@ -69,17 +69,17 @@ Momentum:
momentum: 0.9
Thor:
lr_max: 0.0034
lr_min: 0.00003244 # 3.244e-5
lr_power: 1.0
lr_max: 0.006464
lr_min: 0.000001 # 1e-6
lr_power: 2.0
lr_total_steps: 30000
damping_max: 0.05 # 5e-2
damping_max: 0.007035
damping_min: 0.000001 # 1e-6
damping_power: 1.0
damping_power: 4.0
damping_total_steps: 30000
momentum: 0.9
weight_decay: 0.0005 # 5e-4,
loss_scale: 1.0
weight_decay: 0.00001 # 1e-5
loss_scale: 1024.0
frequency: 100
# ==============================================================================
# base

View File

@ -23,17 +23,17 @@ cfg = edict({
'bert_network': 'large',
'optimizer': 'Thor',
'Thor': edict({
'lr_max': 0.0034,
'lr_min': 3.244e-5,
'lr_power': 1.0,
'lr_max': 0.006464,
'lr_min': 1e-6,
'lr_power': 2.0,
'lr_total_steps': 30000,
'damping_max': 5e-2,
'damping_max': 0.007035,
'damping_min': 1e-6,
'damping_power': 1.0,
'damping_power': 4.0,
'damping_total_steps': 30000,
'momentum': 0.9,
'weight_decay': 5e-4,
'loss_scale': 1.0,
'weight_decay': 0.0,
'loss_scale': 1024.0,
'frequency': 100,
}),
})
@ -91,7 +91,7 @@ if cfg.bert_network == 'large':
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_act="fast_gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,

View File

@ -31,7 +31,6 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,
for file_name in files:
if "tfrecord" in file_name:
data_files.append(os.path.join(data_dir, file_name))
data_files = sorted(data_files)
data_set = ds.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],

View File

@ -40,7 +40,7 @@ bert_net_cfg = BertConfig(
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_act="fast_gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,