!28743 fix rnns error
Merge pull request !28743 from 吕昱峰(Nate.River)/rnn_op
This commit is contained in:
commit
3b00621bd3
|
@ -181,11 +181,14 @@ class _DynamicRNNBase(Cell):
|
|||
|
||||
def construct(self, x, h, seq_length, w_ih, w_hh, b_ih, b_hh):
|
||||
x_dtype = x.dtype
|
||||
w_ih = w_ih.astype(x_dtype)
|
||||
w_hh = w_hh.astype(x_dtype)
|
||||
if b_ih is not None:
|
||||
b_ih = b_ih.astype(x_dtype)
|
||||
b_hh = b_ih.astype(x_dtype)
|
||||
if seq_length is None:
|
||||
return self.recurrent(x, h, w_ih.astype(x_dtype), w_hh.astype(x_dtype), \
|
||||
b_ih.astype(x_dtype), b_hh.astype(x_dtype))
|
||||
return self.variable_recurrent(x, h, seq_length, w_ih.astype(x_dtype), w_hh.astype(x_dtype), \
|
||||
b_ih.astype(x_dtype), b_hh.astype(x_dtype))
|
||||
return self.recurrent(x, h, w_ih, w_hh, b_ih, b_hh)
|
||||
return self.variable_recurrent(x, h, seq_length, w_ih, w_hh, b_ih, b_hh)
|
||||
|
||||
|
||||
class _DynamicRNNRelu(_DynamicRNNBase):
|
||||
|
|
Loading…
Reference in New Issue