!17881 Substitute internal api used in CRNN

From: @c_34
Reviewed-by: @wuxuejian,@guoqi1024
Signed-off-by: @wuxuejian,@guoqi1024
This commit is contained in:
mindspore-ci-bot 2021-06-07 16:34:32 +08:00 committed by Gitee
commit 5358d44888
1 changed files with 3 additions and 6 deletions

View File

@ -14,8 +14,7 @@
# ============================================================================ # ============================================================================
"""Automatic differentiation with grad clip.""" """Automatic differentiation with grad clip."""
import numpy as np import numpy as np
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean, from mindspore import context
_get_parallel_mode)
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.ops import composite as C from mindspore.ops import composite as C
@ -89,13 +88,11 @@ class TrainOneStepCellWithGradClip(Cell):
self.cast = P.Cast() self.cast = P.Cast()
self.concat = P.Concat(axis=0) self.concat = P.Concat(axis=0)
self.ten = Tensor(np.array([10.0]).astype(np.float32)) self.ten = Tensor(np.array([10.0]).astype(np.float32))
parallel_mode = _get_parallel_mode() parallel_mode = context.get_auto_parallel_context("parallel_mode")
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True self.reducer_flag = True
if self.reducer_flag: if self.reducer_flag:
mean = _get_gradients_mean() self.grad_reducer = DistributedGradReducer(optimizer.parameters)
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, data, label): def construct(self, data, label):
weights = self.weights weights = self.weights