!20621 Fix the bug that caused slightly lower accuracy in warpctc.
Merge pull request !20621 from 郑彬/warpctc_0720
This commit is contained in:
commit
8d999a91b6
|
@@ -22,7 +22,6 @@ from mindspore.nn.cell import Cell
|
|||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
from src.model_utils.device_adapter import get_device_num
|
||||
|
||||
compute_norm = C.MultitypeFuncGraph("compute_norm")
|
||||
|
||||
|
@@ -68,7 +67,7 @@ class TrainOneStepCellWithGradClip(Cell):
|
|||
Tensor, a scalar Tensor with shape :math:`()`.
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
def __init__(self, network, optimizer, device_num, sens=1.0):
|
||||
super(TrainOneStepCellWithGradClip, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
|
@@ -87,7 +86,7 @@ class TrainOneStepCellWithGradClip(Cell):
|
|||
self.cast = P.Cast()
|
||||
self.concat = P.Concat(axis=0)
|
||||
self.ten = Tensor(np.array([10.0]).astype(np.float32))
|
||||
if get_device_num() > 1:
|
||||
if device_num > 1:
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters)
|
||||
|
|
|
@@ -145,7 +145,7 @@ def train():
|
|||
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum)
|
||||
|
||||
net = WithLossCell(net, loss)
|
||||
net = TrainOneStepCellWithGradClip(net, opt).set_train()
|
||||
net = TrainOneStepCellWithGradClip(net, opt, device_num).set_train()
|
||||
# define model
|
||||
model = Model(net)
|
||||
# define callbacks
|
||||
|
|
Loading…
Reference in New Issue