diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
index 50b05aecad1..a4fcbc0aa62 100644
--- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
+++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
@@ -315,6 +315,15 @@ def tensor_grad_scale(scale, grad):
     return grad * reciprocal(scale)
 
 
+_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
+grad_overflow = P.FloatStatus()
+
+
+@_grad_overflow.register("Tensor")
+def _tensor_grad_overflow(grad):
+    return grad_overflow(grad)
+
+
 class BertTrainOneStepWithLossScaleCell(nn.Cell):
     """
     Encapsulation class of bert network training.
@@ -347,9 +356,16 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
         self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
         self.cast = P.Cast()
-        self.alloc_status = P.NPUAllocFloatStatus()
-        self.get_status = P.NPUGetFloatStatus()
-        self.clear_before_grad = P.NPUClearFloatStatus()
+        if context.get_context("device_target") == "GPU":
+            self.gpu_target = True
+            self.float_status = P.FloatStatus()
+            self.addn = P.AddN()
+            self.reshape = P.Reshape()
+        else:
+            self.gpu_target = False
+            self.alloc_status = P.NPUAllocFloatStatus()
+            self.get_status = P.NPUGetFloatStatus()
+            self.clear_before_grad = P.NPUClearFloatStatus()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
         self.depend_parameter_use = P.ControlDepend(depend_mode=1)
         self.base = Tensor(1, mstype.float32)
@@ -383,9 +399,11 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
             scaling_sens = self.loss_scale
         else:
             scaling_sens = sens
-        # alloc status and clear should be right before gradoperation
-        init = self.alloc_status()
-        self.clear_before_grad(init)
+        init = False
+        if not self.gpu_target:
+            # alloc status and clear should be right before gradoperation
+            init = self.alloc_status()
+            self.clear_before_grad(init)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -399,8 +417,13 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
             grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        self.get_status(init)
-        flag_sum = self.reduce_sum(init, (0,))
+        if not self.gpu_target:
+            self.get_status(init)
+            flag_sum = self.reduce_sum(init, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
+            flag_sum = self.addn(flag_sum)
+            flag_sum = self.reshape(flag_sum, (()))
         if self.is_distributed:
             # sum overflow flag over devices
             flag_reduce = self.allreduce(flag_sum)
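
Note on the GPU branch above: the NPUAllocFloatStatus/NPUGetFloatStatus/NPUClearFloatStatus
operators read a device-side overflow status register that only exists on Ascend, so on GPU
the overflow flag is derived from the gradients themselves. P.FloatStatus maps each gradient
tensor to a one-element tensor that is non-zero when the tensor contains NaN or Inf, AddN sums
the per-gradient flags, and Reshape collapses the sum to a scalar. A minimal NumPy sketch of
that logic (helper names here are illustrative, not MindSpore API):

    import numpy as np

    def float_status(grad):
        # Mimic P.FloatStatus: [1.0] if grad contains any NaN/Inf, else [0.0].
        return np.array([float(not np.all(np.isfinite(grad)))], dtype=np.float32)

    def overflow_flag(grads):
        # Mimic the GPU branch: per-gradient flags -> AddN -> scalar reshape.
        flags = [float_status(g) for g in grads]  # hyper_map(_grad_overflow, grads)
        flag_sum = np.add.reduce(flags)           # self.addn(flag_sum)
        return flag_sum.reshape(())               # self.reshape(flag_sum, (()))

    grads = [np.ones((2, 2), np.float32), np.array([np.inf], np.float32)]
    print(overflow_flag(grads))  # 1.0 -> overflow, the optimizer step is skipped

A non-zero flag then feeds the same is_distributed/allreduce overflow check as the NPU path,
so the loss-scale update behaves identically on both targets.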