forked from mindspore-Ecosystem/mindspore
!10653 bert on gpu for pre training script supports loss scale
From: @hanhuifeng2020 Reviewed-by: @c_34,@gaoxiong1 Signed-off-by: @c_34
This commit is contained in:
commit
a910086b50
|
@ -315,6 +315,15 @@ def tensor_grad_scale(scale, grad):
|
||||||
return grad * reciprocal(scale)
|
return grad * reciprocal(scale)
|
||||||
|
|
||||||
|
|
||||||
|
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||||
|
grad_overflow = P.FloatStatus()
|
||||||
|
|
||||||
|
|
||||||
|
@_grad_overflow.register("Tensor")
|
||||||
|
def _tensor_grad_overflow(grad):
|
||||||
|
return grad_overflow(grad)
|
||||||
|
|
||||||
|
|
||||||
class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Encapsulation class of bert network training.
|
Encapsulation class of bert network training.
|
||||||
|
@ -347,9 +356,16 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
||||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||||
self.cast = P.Cast()
|
self.cast = P.Cast()
|
||||||
self.alloc_status = P.NPUAllocFloatStatus()
|
if context.get_context("device_target") == "GPU":
|
||||||
self.get_status = P.NPUGetFloatStatus()
|
self.gpu_target = True
|
||||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
self.float_status = P.FloatStatus()
|
||||||
|
self.addn = P.AddN()
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
else:
|
||||||
|
self.gpu_target = False
|
||||||
|
self.alloc_status = P.NPUAllocFloatStatus()
|
||||||
|
self.get_status = P.NPUGetFloatStatus()
|
||||||
|
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||||
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
||||||
self.base = Tensor(1, mstype.float32)
|
self.base = Tensor(1, mstype.float32)
|
||||||
|
@ -383,9 +399,11 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
||||||
scaling_sens = self.loss_scale
|
scaling_sens = self.loss_scale
|
||||||
else:
|
else:
|
||||||
scaling_sens = sens
|
scaling_sens = sens
|
||||||
# alloc status and clear should be right before gradoperation
|
init = False
|
||||||
init = self.alloc_status()
|
if not self.gpu_target:
|
||||||
self.clear_before_grad(init)
|
# alloc status and clear should be right before gradoperation
|
||||||
|
init = self.alloc_status()
|
||||||
|
self.clear_before_grad(init)
|
||||||
grads = self.grad(self.network, weights)(input_ids,
|
grads = self.grad(self.network, weights)(input_ids,
|
||||||
input_mask,
|
input_mask,
|
||||||
token_type_id,
|
token_type_id,
|
||||||
|
@ -399,8 +417,13 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
||||||
grads = self.grad_reducer(grads)
|
grads = self.grad_reducer(grads)
|
||||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
||||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||||
self.get_status(init)
|
if not self.gpu_target:
|
||||||
flag_sum = self.reduce_sum(init, (0,))
|
self.get_status(init)
|
||||||
|
flag_sum = self.reduce_sum(init, (0,))
|
||||||
|
else:
|
||||||
|
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
|
||||||
|
flag_sum = self.addn(flag_sum)
|
||||||
|
flag_sum = self.reshape(flag_sum, (()))
|
||||||
if self.is_distributed:
|
if self.is_distributed:
|
||||||
# sum overflow flag over devices
|
# sum overflow flag over devices
|
||||||
flag_reduce = self.allreduce(flag_sum)
|
flag_reduce = self.allreduce(flag_sum)
|
||||||
|
|
Loading…
Reference in New Issue