forked from mindspore-Ecosystem/mindspore
Add new optimizer THOR option to BERT pretrain script.
This commit is contained in:
parent
856d6f58cf
commit
67a4c62b4b
|
@ -28,7 +28,8 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
|
||||
from mindspore.train.train_thor import ConvertModelUtils
|
||||
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay, THOR
|
||||
from mindspore import log as logger
|
||||
from mindspore.common import set_seed
|
||||
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
|
||||
|
@ -90,8 +91,27 @@ def _get_optimizer(args_opt, network):
|
|||
optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
|
||||
else:
|
||||
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
|
||||
elif cfg.optimizer == "Thor":
|
||||
from src.utils import get_bert_thor_lr, get_bert_thor_damping
|
||||
lr = get_bert_thor_lr()
|
||||
damping = get_bert_thor_damping()
|
||||
split_indices = None
|
||||
if bert_net_cfg.num_hidden_layers == 12:
|
||||
if bert_net_cfg.use_relative_positions:
|
||||
split_indices = [29, 58, 87, 116, 145, 174, 203, 217]
|
||||
else:
|
||||
split_indices = [28, 55, 82, 109, 136, 163, 190, 205]
|
||||
elif bert_net_cfg.num_hidden_layers == 24:
|
||||
if bert_net_cfg.use_relative_positions:
|
||||
split_indices = [30, 90, 150, 210, 270, 330, 390, 421]
|
||||
else:
|
||||
split_indices = [38, 93, 148, 203, 258, 313, 368, 397]
|
||||
optimizer = THOR(network, lr, damping, cfg.Thor.momentum,
|
||||
cfg.Thor.weight_decay, cfg.Thor.loss_scale, cfg.batch_size,
|
||||
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
split_indices=split_indices)
|
||||
else:
|
||||
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
|
||||
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
|
||||
format(cfg.optimizer))
|
||||
return optimizer
|
||||
|
||||
|
@ -244,6 +264,8 @@ def run_pretrain():
|
|||
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
|
||||
|
||||
model = Model(net_with_grads)
|
||||
model = ConvertModelUtils().convert_to_thor_model(model, network=net_with_grads, optimizer=optimizer,
|
||||
frequency=cfg.Thor.frequency)
|
||||
model.train(new_repeat_count, ds, callbacks=callback,
|
||||
dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
|
||||
|
||||
|
|
Loading…
Reference in New Issue