!17881 Substitute internal api used in CRNN
From: @c_34 Reviewed-by: @wuxuejian,@guoqi1024 Signed-off-by: @wuxuejian,@guoqi1024
This commit is contained in:
commit
5358d44888
|
@ -14,8 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Automatic differentiation with grad clip."""
|
"""Automatic differentiation with grad clip."""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
|
from mindspore import context
|
||||||
_get_parallel_mode)
|
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
|
@ -89,13 +88,11 @@ class TrainOneStepCellWithGradClip(Cell):
|
||||||
self.cast = P.Cast()
|
self.cast = P.Cast()
|
||||||
self.concat = P.Concat(axis=0)
|
self.concat = P.Concat(axis=0)
|
||||||
self.ten = Tensor(np.array([10.0]).astype(np.float32))
|
self.ten = Tensor(np.array([10.0]).astype(np.float32))
|
||||||
parallel_mode = _get_parallel_mode()
|
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||||
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
||||||
self.reducer_flag = True
|
self.reducer_flag = True
|
||||||
if self.reducer_flag:
|
if self.reducer_flag:
|
||||||
mean = _get_gradients_mean()
|
self.grad_reducer = DistributedGradReducer(optimizer.parameters)
|
||||||
degree = _get_device_num()
|
|
||||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
|
||||||
|
|
||||||
def construct(self, data, label):
|
def construct(self, data, label):
|
||||||
weights = self.weights
|
weights = self.weights
|
||||||
|
|
Loading…
Reference in New Issue