Fix check string attr of DynamicRnn.

This commit is contained in:
liuxiao93 2021-04-16 19:00:14 +08:00
parent 2edff12952
commit e895365c31
1 changed files with 3 additions and 0 deletions

View File

@ -7261,8 +7261,11 @@ class DynamicRNN(PrimitiveWithInfer):
self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name)
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name)
validator.check_value_type("cell_type", cell_type, [str], self.name)
self.cell_type = validator.check_string(cell_type, ['LSTM'], "cell_type", self.name)
validator.check_value_type("direction", direction, [str], self.name)
self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
validator.check_value_type("activation", activation, [str], self.name)
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):