remove inner api for warpctc

This commit is contained in:
gengdongjie 2021-06-03 15:43:57 +08:00
parent f55866e537
commit 4c0ad3acc4
1 changed files with 3 additions and 8 deletions

View File

@ -14,9 +14,6 @@
# ============================================================================
"""Automatic differentiation with grad clip."""
import numpy as np
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
_get_parallel_mode)
from mindspore.context import ParallelMode
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
@ -25,6 +22,7 @@ from mindspore.nn.cell import Cell
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from src.model_utils.device_adapter import get_device_num
compute_norm = C.MultitypeFuncGraph("compute_norm")
@ -89,13 +87,10 @@ class TrainOneStepCellWithGradClip(Cell):
self.cast = P.Cast()
self.concat = P.Concat(axis=0)
self.ten = Tensor(np.array([10.0]).astype(np.float32))
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
if get_device_num() > 1:
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.grad_reducer = DistributedGradReducer(optimizer.parameters)
def construct(self, data, label):
weights = self.weights