!20621 fix the bug that caused slightly lower accuracy of warpctc.

Merge pull request !20621 from 郑彬/warpctc_0720
Author: i-robot, 2021-07-23 01:27:26 +00:00, committed by Gitee
Commit: 8d999a91b6
2 changed files with 3 additions and 4 deletions

File 1 of 2:

@@ -22,7 +22,6 @@ from mindspore.nn.cell import Cell
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 import mindspore.nn as nn
 from mindspore.common.tensor import Tensor
-from src.model_utils.device_adapter import get_device_num
 compute_norm = C.MultitypeFuncGraph("compute_norm")
@@ -68,7 +67,7 @@ class TrainOneStepCellWithGradClip(Cell):
         Tensor, a scalar Tensor with shape :math:`()`.
     """
-    def __init__(self, network, optimizer, sens=1.0):
+    def __init__(self, network, optimizer, device_num, sens=1.0):
         super(TrainOneStepCellWithGradClip, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_grad()
@@ -87,7 +86,7 @@ class TrainOneStepCellWithGradClip(Cell):
         self.cast = P.Cast()
         self.concat = P.Concat(axis=0)
         self.ten = Tensor(np.array([10.0]).astype(np.float32))
-        if get_device_num() > 1:
+        if device_num > 1:
             self.reducer_flag = True
         if self.reducer_flag:
             self.grad_reducer = DistributedGradReducer(optimizer.parameters)
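
These hunks change how the wrapper decides to build a DistributedGradReducer: the device count is now passed in by the caller rather than looked up via get_device_num() inside the cell. A minimal, hedged sketch of that pattern follows; the class name and toy members are illustrative, not the full warpctc implementation:

import mindspore.nn as nn
from mindspore.nn.cell import Cell
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

class TrainOneStepCellSketch(Cell):
    """Illustrative stand-in for TrainOneStepCellWithGradClip."""
    def __init__(self, network, optimizer, device_num, sens=1.0):
        super(TrainOneStepCellSketch, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.optimizer = optimizer
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        # Decide on the caller-supplied device count instead of querying
        # the environment from inside the cell (the removed import).
        if device_num > 1:
            self.reducer_flag = True
        if self.reducer_flag:
            self.grad_reducer = DistributedGradReducer(optimizer.parameters)

    def construct(self, *inputs):
        # The real class computes gradients, clips them, and applies the
        # optimizer; omitted here to keep the sketch focused on device_num.
        return self.network(*inputs)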

File 2 of 2:

@@ -145,7 +145,7 @@ def train():
     opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum)
     net = WithLossCell(net, loss)
-    net = TrainOneStepCellWithGradClip(net, opt).set_train()
+    net = TrainOneStepCellWithGradClip(net, opt, device_num).set_train()
     # define model
     model = Model(net)
     # define callbacks
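
With the constructor change, train() must supply device_num itself. A hedged sketch of one way it could be computed; MindSpore model-zoo device adapters conventionally read the RANK_SIZE environment variable, but the exact wiring in this script is an assumption:

import os

def resolve_device_num():
    # RANK_SIZE is the usual MindSpore environment variable holding the
    # number of devices in a distributed job; default to 1 (standalone).
    return int(os.getenv("RANK_SIZE", "1"))

device_num = resolve_device_num()
# ...then: TrainOneStepCellWithGradClip(net, opt, device_num).set_train()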