forked from mindspore-Ecosystem/mindspore
add order params for bert to improve performance
parent 183cf5cf5d
commit 29e35a31c0
@@ -106,6 +106,7 @@ def run_pretrain():
         new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
     else:
         args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
+    logger.info("train steps: {}".format(args_opt.train_steps))
     if cfg.optimizer == 'Lamb':
         lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
@@ -117,7 +118,8 @@ def run_pretrain():
         decay_params = list(filter(cfg.Lamb.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
-                        {'params': other_params}]
+                        {'params': other_params},
+                        {'order_params': params}]
         optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
     elif cfg.optimizer == 'Momentum':
         optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
@@ -132,7 +134,8 @@ def run_pretrain():
         decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
-                        {'params': other_params, 'weight_decay': 0.0}]
+                        {'params': other_params, 'weight_decay': 0.0},
+                        {'order_params': params}]
         optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
     else:
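For context, both hunks above use MindSpore's grouped-parameter API: the extra {'order_params': params} entry does not create a new parameter group, it tells the optimizer which order to keep the parameters in when applying updates, which is presumably where the performance gain named in the commit title comes from. A minimal, self-contained sketch of the same pattern follows; the Dense network and the hyperparameter values are illustrative only, not taken from this commit:

import mindspore.nn as nn

# Illustrative network, only used here to obtain a parameter list.
net = nn.Dense(16, 4)
params = net.trainable_params()

# Same grouping idea as run_pretrain.py: decay the weights, skip the biases.
decay_params = list(filter(lambda x: 'bias' not in x.name.lower(), params))
other_params = list(filter(lambda x: x not in decay_params, params))

group_params = [{'params': decay_params, 'weight_decay': 0.01},
                {'params': other_params},
                # 'order_params' defines the parameter update order, not a new group.
                {'order_params': params}]

optimizer = nn.Lamb(group_params, learning_rate=3e-5, eps=1e-6)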
@@ -26,7 +26,7 @@ cfg = edict({
     'optimizer': 'Lamb',
     'AdamWeightDecay': edict({
         'learning_rate': 3e-5,
-        'end_learning_rate': 1e-10,
+        'end_learning_rate': 0.0,
         'power': 5.0,
         'weight_decay': 1e-5,
         'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
@@ -35,7 +35,7 @@ cfg = edict({
     }),
     'Lamb': edict({
         'learning_rate': 3e-5,
-        'end_learning_rate': 1e-10,
+        'end_learning_rate': 0.0,
         'power': 10.0,
         'warmup_steps': 10000,
         'weight_decay': 0.01,
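Both config hunks lower end_learning_rate from 1e-10 to 0.0, so the decayed learning rate now reaches exactly zero instead of flooring at 1e-10. A small sketch of the effect, assuming the standard polynomial-decay formula that schedules like BertLearningRate are commonly built on (the step counts and start value below are illustrative):

def polynomial_decay_lr(step, decay_steps, start_lr=3e-5, end_lr=0.0, power=10.0):
    # Interpolate from start_lr down to end_lr over decay_steps.
    frac = min(step, decay_steps) / decay_steps
    return (start_lr - end_lr) * (1.0 - frac) ** power + end_lr

# With end_lr = 0.0 the schedule reaches exactly zero at the final step.
print(polynomial_decay_lr(0, 10000))      # 3e-05
print(polynomial_decay_lr(10000, 10000))  # 0.0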