!28743 fix rnns error
Merge pull request !28743 from 吕昱峰(Nate.River)/rnn_op
This commit is contained in:
commit
3b00621bd3
|
@ -181,11 +181,14 @@ class _DynamicRNNBase(Cell):
|
||||||
|
|
||||||
def construct(self, x, h, seq_length, w_ih, w_hh, b_ih, b_hh):
|
def construct(self, x, h, seq_length, w_ih, w_hh, b_ih, b_hh):
|
||||||
x_dtype = x.dtype
|
x_dtype = x.dtype
|
||||||
|
w_ih = w_ih.astype(x_dtype)
|
||||||
|
w_hh = w_hh.astype(x_dtype)
|
||||||
|
if b_ih is not None:
|
||||||
|
b_ih = b_ih.astype(x_dtype)
|
||||||
|
b_hh = b_ih.astype(x_dtype)
|
||||||
if seq_length is None:
|
if seq_length is None:
|
||||||
return self.recurrent(x, h, w_ih.astype(x_dtype), w_hh.astype(x_dtype), \
|
return self.recurrent(x, h, w_ih, w_hh, b_ih, b_hh)
|
||||||
b_ih.astype(x_dtype), b_hh.astype(x_dtype))
|
return self.variable_recurrent(x, h, seq_length, w_ih, w_hh, b_ih, b_hh)
|
||||||
return self.variable_recurrent(x, h, seq_length, w_ih.astype(x_dtype), w_hh.astype(x_dtype), \
|
|
||||||
b_ih.astype(x_dtype), b_hh.astype(x_dtype))
|
|
||||||
|
|
||||||
|
|
||||||
class _DynamicRNNRelu(_DynamicRNNBase):
|
class _DynamicRNNRelu(_DynamicRNNBase):
|
||||||
|
|
Loading…
Reference in New Issue