forked from mindspore-Ecosystem/mindspore
!15288 Fix check string attr of DynamicRnn.
From: @liu_xiao_93 Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @liangchenghui
This commit is contained in:
commit
721bcca85b
|
@ -7261,8 +7261,11 @@ class DynamicRNN(PrimitiveWithInfer):
|
|||
self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name)
|
||||
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
|
||||
self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name)
|
||||
validator.check_value_type("cell_type", cell_type, [str], self.name)
|
||||
self.cell_type = validator.check_string(cell_type, ['LSTM'], "cell_type", self.name)
|
||||
validator.check_value_type("direction", direction, [str], self.name)
|
||||
self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
|
||||
validator.check_value_type("activation", activation, [str], self.name)
|
||||
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
|
||||
|
||||
def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
|
||||
|
|
Loading…
Reference in New Issue