!10653 bert pre-training script supports loss scale on GPU

From: @hanhuifeng2020
Reviewed-by: @c_34,@gaoxiong1
Signed-off-by: @c_34
mindspore-ci-bot 2020-12-28 17:21:51 +08:00 committed by Gitee
commit a910086b50
1 changed file with 31 additions and 8 deletions


@@ -315,6 +315,15 @@ def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)

_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()
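# GPU has no hardware status register for float exceptions, so overflow
# is detected per gradient: P.FloatStatus returns a nonzero (1,)-shaped
# tensor when its input contains inf or nan.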
@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    return grad_overflow(grad)

class BertTrainOneStepWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of bert network training.
@@ -347,6 +356,13 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
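        # GPU lacks the NPU float-status ops, so overflow is checked by
        # mapping FloatStatus over each gradient; Ascend keeps the
        # dedicated NPU*FloatStatus status-register ops.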
        if context.get_context("device_target") == "GPU":
            self.gpu_target = True
            self.float_status = P.FloatStatus()
            self.addn = P.AddN()
            self.reshape = P.Reshape()
        else:
            self.gpu_target = False
            self.alloc_status = P.NPUAllocFloatStatus()
            self.get_status = P.NPUGetFloatStatus()
            self.clear_before_grad = P.NPUClearFloatStatus()
@@ -383,6 +399,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
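        # `init` is a placeholder so the name is defined on GPU, where
        # no status register is allocated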
        init = False
        if not self.gpu_target:
            # alloc status and clear should be right before the grad operation
            init = self.alloc_status()
            self.clear_before_grad(init)
@@ -399,8 +417,13 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
            grads = self.grad_reducer(grads)
            grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
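        # read back the overflow flag: Ascend fetches it from the status
        # register, GPU reduces the per-gradient FloatStatus outputs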
        if not self.gpu_target:
            self.get_status(init)
            flag_sum = self.reduce_sum(init, (0,))
        else:
            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
            flag_sum = self.addn(flag_sum)
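            # reshape the (1,) flag to a scalar to match the Ascend branch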
            flag_sum = self.reshape(flag_sum, (()))
if self.is_distributed:
# sum overflow flag over devices
flag_reduce = self.allreduce(flag_sum)
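
For reference, below is a minimal standalone sketch of the GPU overflow check this change wires into BertTrainOneStepWithLossScaleCell; the imports and the GpuOverflowCheck cell name are illustrative, not part of the commit:

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()

@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    # nonzero (1,)-shaped tensor if `grad` contains inf or nan
    return grad_overflow(grad)

class GpuOverflowCheck(nn.Cell):
    """Reduce per-gradient float-status flags to a single scalar flag."""
    def __init__(self):
        super(GpuOverflowCheck, self).__init__()
        self.hyper_map = C.HyperMap()
        self.addn = P.AddN()
        self.reshape = P.Reshape()

    def construct(self, grads):
        # one (1,)-shaped flag per gradient, summed and reshaped to a scalar
        flags = self.hyper_map(F.partial(_grad_overflow), grads)
        flag_sum = self.addn(flags)
        return self.reshape(flag_sum, (()))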