forked from mindspore-Ecosystem/mindspore
!9663 remove unused parameters in config
From: @sl_wang Reviewed-by: @wang_zi_dong,@kisnwang Signed-off-by: @wang_zi_dong
This commit is contained in:
commit
7bfe1a5d34
|
@ -25,12 +25,11 @@ from src.config import cfg
|
|||
from src.dataset import create_bert_dataset
|
||||
from src.lr_generator import get_bert_lr, get_bert_damping
|
||||
from src.model_thor import Model
|
||||
from src.utils import LossCallBack, BertLearningRate
|
||||
from src.utils import LossCallBack
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.communication.management as D
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
|
||||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
from mindspore.context import ParallelMode
|
||||
|
@ -68,38 +67,8 @@ def _set_bert_all_reduce_split():
|
|||
|
||||
|
||||
def _get_optimizer(args_opt, network):
|
||||
"""get bert optimizer, support Lamb, Momentum, AdamWeightDecay and Thor."""
|
||||
if cfg.optimizer == 'Lamb':
|
||||
lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
|
||||
end_learning_rate=cfg.Lamb.end_learning_rate,
|
||||
warmup_steps=cfg.Lamb.warmup_steps,
|
||||
decay_steps=args_opt.train_steps,
|
||||
power=cfg.Lamb.power)
|
||||
params = network.trainable_params()
|
||||
decay_params = list(filter(cfg.Lamb.decay_filter, params))
|
||||
other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
|
||||
group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
|
||||
{'params': other_params},
|
||||
{'order_params': params}]
|
||||
optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
|
||||
elif cfg.optimizer == 'Momentum':
|
||||
optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
|
||||
momentum=cfg.Momentum.momentum)
|
||||
elif cfg.optimizer == 'AdamWeightDecay':
|
||||
lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
|
||||
end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
|
||||
warmup_steps=cfg.AdamWeightDecay.warmup_steps,
|
||||
decay_steps=args_opt.train_steps,
|
||||
power=cfg.AdamWeightDecay.power)
|
||||
params = network.trainable_params()
|
||||
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
|
||||
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
|
||||
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
|
||||
{'params': other_params, 'weight_decay': 0.0},
|
||||
{'order_params': params}]
|
||||
|
||||
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
|
||||
elif cfg.optimizer == "Thor":
|
||||
"""get thor optimizer."""
|
||||
if cfg.optimizer == "Thor":
|
||||
if args_opt.distribute == "true":
|
||||
from src.thor_for_bert_arg import THOR
|
||||
else:
|
||||
|
@ -112,8 +81,7 @@ def _get_optimizer(args_opt, network):
|
|||
cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
|
||||
bert_net_cfg.batch_size, damping)
|
||||
else:
|
||||
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
|
||||
format(cfg.optimizer))
|
||||
raise ValueError("Don't support optimizer {}, only support [Thor]".format(cfg.optimizer))
|
||||
return optimizer
|
||||
|
||||
|
||||
|
|
|
@ -20,28 +20,6 @@ from easydict import EasyDict as edict
|
|||
cfg = edict({
|
||||
'bert_network': 'large',
|
||||
'optimizer': 'Thor',
|
||||
'AdamWeightDecay': edict({
|
||||
'learning_rate': 3e-5,
|
||||
'end_learning_rate': 1e-10,
|
||||
'power': 5.0,
|
||||
'weight_decay': 1e-5,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
'eps': 1e-6,
|
||||
'warmup_steps': 10000,
|
||||
}),
|
||||
'Lamb': edict({
|
||||
'learning_rate': 3e-5,
|
||||
'end_learning_rate': 1e-10,
|
||||
'power': 10.0,
|
||||
'warmup_steps': 10000,
|
||||
'weight_decay': 0.01,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
'eps': 1e-6,
|
||||
}),
|
||||
'Momentum': edict({
|
||||
'learning_rate': 2e-5,
|
||||
'momentum': 0.9,
|
||||
}),
|
||||
'Thor': edict({
|
||||
'momentum': 0.9,
|
||||
'weight_decay': 5e-4,
|
||||
|
|
Loading…
Reference in New Issue