!28743 fix rnns error

Merge pull request !28743 from 吕昱峰(Nate.River)/rnn_op
This commit is contained in:
i-robot 2022-01-10 01:59:50 +00:00 committed by Gitee
commit 3b00621bd3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 7 additions and 4 deletions

View File

@ -181,11 +181,14 @@ class _DynamicRNNBase(Cell):
def construct(self, x, h, seq_length, w_ih, w_hh, b_ih, b_hh):
x_dtype = x.dtype
w_ih = w_ih.astype(x_dtype)
w_hh = w_hh.astype(x_dtype)
if b_ih is not None:
b_ih = b_ih.astype(x_dtype)
b_hh = b_ih.astype(x_dtype)
if seq_length is None:
return self.recurrent(x, h, w_ih.astype(x_dtype), w_hh.astype(x_dtype), \
b_ih.astype(x_dtype), b_hh.astype(x_dtype))
return self.variable_recurrent(x, h, seq_length, w_ih.astype(x_dtype), w_hh.astype(x_dtype), \
b_ih.astype(x_dtype), b_hh.astype(x_dtype))
return self.recurrent(x, h, w_ih, w_hh, b_ih, b_hh)
return self.variable_recurrent(x, h, seq_length, w_ih, w_hh, b_ih, b_hh)
class _DynamicRNNRelu(_DynamicRNNBase):