diff --git a/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py b/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py
index bc51ba5d483..046b2adbe2c 100644
--- a/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py
+++ b/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py
@@ -370,7 +370,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
             self.reducer_flag = True
-        self.grad_reducer = None
+        self.grad_reducer = F.identity
         if self.reducer_flag:
             mean = context.get_auto_parallel_context("mirror_mean")
             degree = get_group_size()
@@ -428,9 +428,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
                                                  mstype.float32))
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
         grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        # apply grad reducer on grads
+        grads = self.grad_reducer(grads)
         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
         if self.is_distributed:
diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py
index c6d61e6983a..ba8e6cbb7c8 100644
--- a/mindspore/nn/wrap/loss_scale.py
+++ b/mindspore/nn/wrap/loss_scale.py
@@ -220,7 +220,7 @@ class TrainOneStepWithLossScaleCell(Cell):
         self.depend_parameter_use = ControlDepend(depend_mode=1)
         self.allreduce = P.AllReduce()
         self.parallel_mode = _get_parallel_mode()
-        self.grad_reducer = None
+        self.grad_reducer = F.identity
         self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
         if self.reducer_flag:
             mean = _get_mirror_mean()
@@ -250,9 +250,8 @@ class TrainOneStepWithLossScaleCell(Cell):
             scaling_sens = sens
         grads = self.grad(self.network, weights)(data, label, F.cast(scaling_sens, F.dtype(loss)))
         grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        # apply grad reducer on grads
+        grads = self.grad_reducer(grads)
         # get the overflow buffer
         if not self.gpu_target:
             self.get_status(init)
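
Both files make the same change: the gradient reducer now defaults to F.identity instead of None, so the call site can invoke self.grad_reducer(grads) unconditionally and the `if self.reducer_flag:` branch around it disappears. In the distributed case the attribute is still overwritten (with a DistributedGradReducer, as before); in the single-device case the identity function simply passes the gradients through. A minimal standalone sketch of this null-object-style pattern (plain Python with illustrative names, not the MindSpore API):

    # Sketch only: `identity` stands in for F.identity, and the lambda
    # stands in for DistributedGradReducer. Names here are illustrative.

    def identity(x):
        """No-op default: return the gradients unchanged."""
        return x

    class TrainOneStep:
        def __init__(self, distributed=False):
            # Default to a callable no-op instead of None, so the step
            # never needs to branch on a reducer flag.
            self.grad_reducer = identity
            if distributed:
                # Placeholder for a real reducer, which would all-reduce
                # (and optionally mean) gradients across devices.
                self.grad_reducer = lambda grads: tuple(g / 2.0 for g in grads)

        def step(self, grads):
            # Unconditional call: identity on one device, real reducer otherwise.
            return self.grad_reducer(grads)

    print(TrainOneStep().step((2.0, 4.0)))                  # (2.0, 4.0)
    print(TrainOneStep(distributed=True).step((2.0, 4.0)))  # (1.0, 2.0)

The attribute is always a valid callable, so the single-device and distributed code paths share one line at the call site instead of diverging on a flag.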