forked from mindspore-Ecosystem/mindspore
!5632 Add clip_by_global_norm in bert
Merge pull request !5632 from chenhaozhe/add-global-norm-to-bert
Commit 44a9c25251
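In short, the commit adds an optional clip-by-global-norm path to BERT pretraining: instead of clipping each gradient tensor on its own, all gradients can be rescaled by one shared factor derived from their combined norm, controlled by a new cfg.enable_global_norm switch. A rough NumPy sketch of the rule implemented by the ClipByGlobalNorm cell added further down (values are made up; note that this implementation divides the summed squares by the number of gradient tensors before taking the square root):

import numpy as np

clip_norm = 1.0
grads = [np.array([3.0, 4.0]), np.array([1.0, 0.0])]   # made-up gradient tensors

# "global norm" as computed by the new GlobalNorm cell
global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads) / len(grads))

# gradients are rescaled only when the norm exceeds the threshold
factor = clip_norm / max(global_norm, clip_norm)
clipped = [g * factor for g in grads]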
@@ -179,12 +179,14 @@ def run_pretrain():
         if args_opt.accumulation_steps <= 1:
             net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
-                                                               scale_update_cell=update_cell)
+                                                               scale_update_cell=update_cell,
+                                                               enable_global_norm=cfg.enable_global_norm)
         else:
             accumulation_steps = args_opt.accumulation_steps
             net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                                        scale_update_cell=update_cell,
-                                                                       accumulation_steps=accumulation_steps)
+                                                                       accumulation_steps=accumulation_steps,
+                                                                       enable_global_norm=cfg.enable_global_norm)
     else:
         net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
@@ -29,6 +29,7 @@ from mindspore.communication.management import get_group_size
 from mindspore import context
 from mindspore.ops import _selected_ops
 from .bert_model import BertModel
+from .utils import ClipByGlobalNorm

 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
@@ -348,11 +349,12 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         optimizer (Optimizer): Optimizer for updating the weights.
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
-    def __init__(self, network, optimizer, scale_update_cell=None):
+    def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
         super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
+        self.enable_global_norm = enable_global_norm
         self.grad = C.GradOperation(get_by_list=True,
                                     sens_param=True)
         self.reducer_flag = False
@@ -419,7 +421,10 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
-        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        if self.enable_global_norm:
+            grads = ClipByGlobalNorm()(grads)
+        else:
+            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
         if self.is_distributed:
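When enable_global_norm is set, the new branch replaces the per-tensor clipping (clip_grad mapped over each gradient via hyper_map) with a single shared rescaling factor; the else branch keeps the existing path untouched. A minimal NumPy contrast of the two behaviours (hypothetical values; clip_grad's exact per-tensor rule is defined elsewhere in this file):

import numpy as np

clip_value = 1.0
grads = [np.array([3.0, 4.0]), np.array([0.3, 0.4])]   # made-up gradient tensors

# per-tensor clipping: each tensor is bounded against its own L2 norm,
# so the large tensor shrinks while the small one passes through unchanged
per_tensor = [g * min(1.0, clip_value / np.linalg.norm(g)) for g in grads]

# global-norm clipping: one factor for the whole tuple, so the relative
# scale between the tensors is preserved
global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads) / len(grads))
by_global_norm = [g * (clip_value / max(global_norm, clip_value)) for g in grads]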
@@ -474,12 +479,13 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
         accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
                                   batch_size * accumulation_steps. Default: 1.
     """
-    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1):
+    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
         super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.accumulation_steps = accumulation_steps
+        self.enable_global_norm = enable_global_norm
         self.one = Tensor(np.array([1]).astype(np.int32))
         self.zero = Tensor(np.array([0]).astype(np.int32))
         self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
@@ -580,7 +586,10 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
             grads = self.grad_reducer(self.accu_grads)
             scaling = scaling_sens * self.degree * self.accumulation_steps
             grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
-            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+            if self.enable_global_norm:
+                grads = ClipByGlobalNorm()(grads)
+            else:
+                grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
             accu_overflow = self.overflow_reducer(accu_overflow)
             F.control_depend(grads, accu_overflow)
             overflow = self.less_equal(self.base, accu_overflow)
@@ -24,6 +24,7 @@ cfg = edict({
     'scale_factor': 2,
     'scale_window': 1000,
     'optimizer': 'Lamb',
+    'enable_global_norm': False,
     'AdamWeightDecay': edict({
         'learning_rate': 3e-5,
         'end_learning_rate': 0.0,
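To actually use the new path, the flag presumably has to be flipped in this cfg block; a minimal sketch of the relevant slice with global-norm clipping enabled (illustration only, the commit's default stays False):

from easydict import edict

cfg = edict({
    'optimizer': 'Lamb',
    'enable_global_norm': True,   # route gradients through ClipByGlobalNorm instead of clip_grad
})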
@@ -115,6 +116,5 @@ if cfg.bert_network == 'large':
         input_mask_from_dataset=True,
         token_type_ids_from_dataset=True,
         dtype=mstype.float32,
-        compute_type=mstype.float16,
-        enable_fused_layernorm=True
+        compute_type=mstype.float16
     )
@@ -23,12 +23,62 @@ import numpy as np
import mindspore.nn as nn
from mindspore import log as logger
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR


get_square_sum = C.MultitypeFuncGraph("get_square_sum")
@get_square_sum.register("Tensor")
def _get_square_sum(grad):
    norm = P.ReduceSum(False)(F.square(grad), ())
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm


apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
    grad = grad * clip_norm / global_norm
    return grad


class GlobalNorm(nn.Cell):
    """
    Calculate the global norm value of given tensors
    """
    def __init__(self):
        super(GlobalNorm, self).__init__()
        self.norm = nn.Norm()
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        square_sum = self.hyper_map(get_square_sum, grads)
        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
        return global_norms


class ClipByGlobalNorm(nn.Cell):
    """
    Clip grads by global norm
    """
    def __init__(self, clip_norm=1.0):
        super(ClipByGlobalNorm, self).__init__()
        self.global_norm = GlobalNorm()
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        global_norm = self.global_norm(grads)
        cond = P.GreaterEqual()(global_norm, self.clip_norm)
        global_norm = F.select(cond, global_norm, self.clip_norm)
        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
        return grads


class CrossEntropyCalculation(nn.Cell):
    """
    Cross Entropy loss
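A hedged usage sketch of the new cell, assuming this file is importable as src.utils (import path and tensor values are made up):

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype

from src.utils import ClipByGlobalNorm   # assumed import path

grads = (Tensor(np.array([3.0, 4.0]), mstype.float32),
         Tensor(np.array([1.0, 0.0]), mstype.float32))

# returns a tuple of gradients rescaled by clip_norm / max(global_norm, clip_norm)
clipped = ClipByGlobalNorm(clip_norm=1.0)(grads)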