diff --git a/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py b/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py index 22a212a489c..bc51ba5d483 100644 --- a/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py +++ b/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py @@ -403,9 +403,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): sens=None): """Defines the computation performed.""" weights = self.weights - # alloc status - init = self.alloc_status() - self.clear_before_grad(init) loss = self.network(input_ids, input_mask, token_type_id, @@ -417,6 +414,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): scaling_sens = self.loss_scale else: scaling_sens = sens + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + self.clear_before_grad(init) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id,