forked from mindspore-Ecosystem/mindspore
align the datatype of lstm input in warpctc
This commit is contained in:
parent
d8a7fd8801
commit
846d8691da
|
@ -61,7 +61,6 @@ class StackedRNN(nn.Cell):
|
|||
self.fc_weight = Tensor(np.random.random((hidden_size, num_class)).astype(np.float16))
|
||||
self.fc_bias = Tensor(np.random.random(self.num_class).astype(np.float16))
|
||||
|
||||
self.cast = P.Cast()
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.matmul = nn.MatMul()
|
||||
|
@ -118,6 +117,7 @@ class StackedRNNForGPU(nn.Cell):
|
|||
self.transpose = P.Transpose()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.cast(x, mstype.float32)
|
||||
x = self.transpose(x, (3, 0, 2, 1))
|
||||
x = self.reshape(x, (-1, self.batch_size, self.input_size))
|
||||
output, _ = self.lstm(x, (self.h, self.c))
|
||||
|
|
Loading…
Reference in New Issue