remove unused optimizer inf config file

This commit is contained in:
mwang 2020-12-08 20:12:11 +08:00
parent 651b1c3577
commit 56c8c346dd
2 changed files with 4 additions and 58 deletions

View File

@ -25,12 +25,11 @@ from src.config import cfg
from src.dataset import create_bert_dataset
from src.lr_generator import get_bert_lr, get_bert_damping
from src.model_thor import Model
from src.utils import LossCallBack, BertLearningRate
from src.utils import LossCallBack
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore import context
from mindspore import log as logger
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.context import ParallelMode
@ -68,38 +67,8 @@ def _set_bert_all_reduce_split():
def _get_optimizer(args_opt, network):
"""get bert optimizer, support Lamb, Momentum, AdamWeightDecay and Thor."""
if cfg.optimizer == 'Lamb':
lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
end_learning_rate=cfg.Lamb.end_learning_rate,
warmup_steps=cfg.Lamb.warmup_steps,
decay_steps=args_opt.train_steps,
power=cfg.Lamb.power)
params = network.trainable_params()
decay_params = list(filter(cfg.Lamb.decay_filter, params))
other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
{'params': other_params},
{'order_params': params}]
optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
elif cfg.optimizer == 'Momentum':
optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
momentum=cfg.Momentum.momentum)
elif cfg.optimizer == 'AdamWeightDecay':
lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
warmup_steps=cfg.AdamWeightDecay.warmup_steps,
decay_steps=args_opt.train_steps,
power=cfg.AdamWeightDecay.power)
params = network.trainable_params()
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
elif cfg.optimizer == "Thor":
"""get thor optimizer."""
if cfg.optimizer == "Thor":
if args_opt.distribute == "true":
from src.thor_for_bert_arg import THOR
else:
@ -112,8 +81,7 @@ def _get_optimizer(args_opt, network):
cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
bert_net_cfg.batch_size, damping)
else:
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
format(cfg.optimizer))
raise ValueError("Don't support optimizer {}, only support [Thor]".format(cfg.optimizer))
return optimizer

View File

@ -20,28 +20,6 @@ from easydict import EasyDict as edict
cfg = edict({
'bert_network': 'large',
'optimizer': 'Thor',
'AdamWeightDecay': edict({
'learning_rate': 3e-5,
'end_learning_rate': 1e-10,
'power': 5.0,
'weight_decay': 1e-5,
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
'eps': 1e-6,
'warmup_steps': 10000,
}),
'Lamb': edict({
'learning_rate': 3e-5,
'end_learning_rate': 1e-10,
'power': 10.0,
'warmup_steps': 10000,
'weight_decay': 0.01,
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
'eps': 1e-6,
}),
'Momentum': edict({
'learning_rate': 2e-5,
'momentum': 0.9,
}),
'Thor': edict({
'momentum': 0.9,
'weight_decay': 5e-4,