script update for bert
This commit is contained in:
parent
ca6da6751f
commit
af4923123c
|
@ -117,8 +117,7 @@ def run_pretrain():
|
|||
decay_params = list(filter(cfg.Lamb.decay_filter, params))
|
||||
other_params = list(filter(lambda x: x not in decay_params, params))
|
||||
group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
|
||||
{'params': other_params},
|
||||
{'order_params': params}]
|
||||
{'params': other_params}]
|
||||
optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
|
||||
elif cfg.optimizer == 'Momentum':
|
||||
optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
|
||||
|
@ -133,8 +132,7 @@ def run_pretrain():
|
|||
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
|
||||
other_params = list(filter(lambda x: x not in decay_params, params))
|
||||
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
|
||||
{'params': other_params, 'weight_decay': 0.0},
|
||||
{'order_params': params}]
|
||||
{'params': other_params, 'weight_decay': 0.0}]
|
||||
|
||||
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
|
||||
else:
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
|
@ -55,7 +55,7 @@ class BertFinetuneCell(nn.Cell):
|
|||
|
||||
super(BertFinetuneCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad',
|
||||
get_by_list=True,
|
||||
|
@ -158,7 +158,7 @@ class BertSquadCell(nn.Cell):
|
|||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(BertSquadCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||
self.reducer_flag = False
|
||||
|
|
|
@ -21,7 +21,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
|
@ -270,7 +270,7 @@ class BertTrainOneStepCell(nn.Cell):
|
|||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
|
@ -349,7 +349,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad',
|
||||
get_by_list=True,
|
||||
|
|
|
@ -133,6 +133,9 @@ class BertLearningRate(LearningRateSchedule):
|
|||
"""
|
||||
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
|
||||
super(BertLearningRate, self).__init__()
|
||||
self.warmup_flag = False
|
||||
if warmup_steps > 0:
|
||||
self.warmup_flag = True
|
||||
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
|
||||
self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
|
||||
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
|
||||
|
@ -142,8 +145,11 @@ class BertLearningRate(LearningRateSchedule):
|
|||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, global_step):
|
||||
decay_lr = self.decay_lr(global_step)
|
||||
if self.warmup_flag:
|
||||
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
|
||||
warmup_lr = self.warmup_lr(global_step)
|
||||
decay_lr = self.decay_lr(global_step)
|
||||
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
|
||||
else:
|
||||
lr = decay_lr
|
||||
return lr
|
||||
|
|
Loading…
Reference in New Issue