forked from mindspore-Ecosystem/mindspore
remove unused optimizer in config file
This commit is contained in:
parent 651b1c3577
commit 56c8c346dd
@@ -25,12 +25,11 @@ from src.config import cfg
 from src.dataset import create_bert_dataset
 from src.lr_generator import get_bert_lr, get_bert_damping
 from src.model_thor import Model
-from src.utils import LossCallBack, BertLearningRate
+from src.utils import LossCallBack
 import mindspore.common.dtype as mstype
 import mindspore.communication.management as D
 from mindspore import context
 from mindspore import log as logger
-from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.context import ParallelMode
@@ -68,38 +67,8 @@ def _set_bert_all_reduce_split():
 
 
 def _get_optimizer(args_opt, network):
-    """get bert optimizer, support Lamb, Momentum, AdamWeightDecay and Thor."""
-    if cfg.optimizer == 'Lamb':
-        lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
-                                       end_learning_rate=cfg.Lamb.end_learning_rate,
-                                       warmup_steps=cfg.Lamb.warmup_steps,
-                                       decay_steps=args_opt.train_steps,
-                                       power=cfg.Lamb.power)
-        params = network.trainable_params()
-        decay_params = list(filter(cfg.Lamb.decay_filter, params))
-        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
-        group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
-                        {'params': other_params},
-                        {'order_params': params}]
-        optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
-    elif cfg.optimizer == 'Momentum':
-        optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
-                             momentum=cfg.Momentum.momentum)
-    elif cfg.optimizer == 'AdamWeightDecay':
-        lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
-                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
-                                       warmup_steps=cfg.AdamWeightDecay.warmup_steps,
-                                       decay_steps=args_opt.train_steps,
-                                       power=cfg.AdamWeightDecay.power)
-        params = network.trainable_params()
-        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
-        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
-        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
-                        {'params': other_params, 'weight_decay': 0.0},
-                        {'order_params': params}]
-
-        optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
-    elif cfg.optimizer == "Thor":
+    """get thor optimizer."""
+    if cfg.optimizer == "Thor":
         if args_opt.distribute == "true":
             from src.thor_for_bert_arg import THOR
         else:
@@ -112,8 +81,7 @@ def _get_optimizer(args_opt, network):
                          cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
                          bert_net_cfg.batch_size, damping)
     else:
-        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
-                         format(cfg.optimizer))
+        raise ValueError("Don't support optimizer {}, only support [Thor]".format(cfg.optimizer))
     return optimizer
 
 
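Taken together, the hunks above leave _get_optimizer with a single Thor branch. The snippet below is a minimal, standalone sketch of that remaining control flow, not the MindSpore code itself: ThorStub, make_optimizer, and the plain-dict cfg are illustrative placeholders, while the real function imports THOR from src.thor_for_bert_arg (or its non-distributed counterpart) and passes it the learning-rate and damping schedules plus the cfg.Thor values visible in the context lines.

# Standalone sketch of the control flow left in _get_optimizer after this
# commit. ThorStub and make_optimizer are illustrative placeholders, not
# MindSpore APIs; the real code builds THOR from src.thor_for_bert_arg with
# the lr/damping schedules and cfg.Thor values shown in the hunks above.

cfg = {"optimizer": "Thor", "Thor": {"momentum": 0.9, "weight_decay": 5e-4}}


class ThorStub:
    """Stand-in for the THOR optimizer; it only records its hyper-parameters."""

    def __init__(self, params, momentum, weight_decay):
        self.params = list(params)
        self.momentum = momentum
        self.weight_decay = weight_decay


def make_optimizer(params):
    """Mirror the trimmed function: build Thor or raise, nothing in between."""
    if cfg["optimizer"] == "Thor":
        return ThorStub(params, cfg["Thor"]["momentum"], cfg["Thor"]["weight_decay"])
    raise ValueError("Don't support optimizer {}, only support [Thor]".format(cfg["optimizer"]))


print(make_optimizer(["embedding.weight", "dense.bias"]).momentum)  # prints 0.9

Keeping the else branch as an explicit ValueError, rather than silently defaulting to Thor, preserves the old behaviour of failing fast on a misconfigured cfg.optimizer. The remaining hunk below is from the optimizer config file, where the now-unused Lamb, Momentum, and AdamWeightDecay blocks are dropped.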
@@ -20,28 +20,6 @@ from easydict import EasyDict as edict
 cfg = edict({
     'bert_network': 'large',
     'optimizer': 'Thor',
-    'AdamWeightDecay': edict({
-        'learning_rate': 3e-5,
-        'end_learning_rate': 1e-10,
-        'power': 5.0,
-        'weight_decay': 1e-5,
-        'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
-        'eps': 1e-6,
-        'warmup_steps': 10000,
-    }),
-    'Lamb': edict({
-        'learning_rate': 3e-5,
-        'end_learning_rate': 1e-10,
-        'power': 10.0,
-        'warmup_steps': 10000,
-        'weight_decay': 0.01,
-        'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
-        'eps': 1e-6,
-    }),
-    'Momentum': edict({
-        'learning_rate': 2e-5,
-        'momentum': 0.9,
-    }),
     'Thor': edict({
         'momentum': 0.9,
         'weight_decay': 5e-4,
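For reference, the trimmed cfg presumably ends up close to the sketch below, assembled from the context lines of this hunk. Only bert_network, optimizer, and the first two Thor keys are visible in the diff; loss_scale appears here because _get_optimizer still reads cfg.Thor.loss_scale above, but its value, and any further Thor keys, are assumptions rather than content of this commit.

# Approximate shape of the config after the unused optimizer blocks are gone.
# Anything beyond the visible context (the loss_scale value, extra Thor
# fields) is an assumption, not taken from this diff.
from easydict import EasyDict as edict

cfg = edict({
    'bert_network': 'large',
    'optimizer': 'Thor',
    'Thor': edict({
        'momentum': 0.9,
        'weight_decay': 5e-4,
        'loss_scale': 1.0,  # referenced as cfg.Thor.loss_scale in _get_optimizer; value assumed
    }),
})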