diff --git a/model_zoo/official/cv/crnn/src/crnn_for_train.py b/model_zoo/official/cv/crnn/src/crnn_for_train.py
index 8d767ade348..fad288c36f4 100644
--- a/model_zoo/official/cv/crnn/src/crnn_for_train.py
+++ b/model_zoo/official/cv/crnn/src/crnn_for_train.py
@@ -14,8 +14,7 @@
 # ============================================================================
 """Automatic differentiation with grad clip."""
 import numpy as np
-from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
-                                       _get_parallel_mode)
+from mindspore import context
 from mindspore.context import ParallelMode
 from mindspore.common import dtype as mstype
 from mindspore.ops import composite as C
@@ -89,13 +88,11 @@ class TrainOneStepCellWithGradClip(Cell):
         self.cast = P.Cast()
         self.concat = P.Concat(axis=0)
         self.ten = Tensor(np.array([10.0]).astype(np.float32))
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
         if self.reducer_flag:
-            mean = _get_gradients_mean()
-            degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters)
 
     def construct(self, data, label):
         weights = self.weights
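
Note: the diff replaces the private helpers from mindspore.parallel._utils with public interfaces. A minimal standalone sketch of the same pattern, assuming the MindSpore 1.x public API (the helper name build_grad_reducer is hypothetical and only for illustration):

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn import DistributedGradReducer

def build_grad_reducer(optimizer):
    """Return a DistributedGradReducer only when running in a data/hybrid parallel mode."""
    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
        # mean and degree are left to DistributedGradReducer's defaults
        # instead of being queried via _get_gradients_mean()/_get_device_num().
        return DistributedGradReducer(optimizer.parameters)
    return None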