dims check

This commit is contained in:
wilfChen 2020-05-09 14:23:41 +08:00
parent bab6e0f549
commit 8eba2c5147
1 changed files with 2 additions and 1 deletions

View File

@ -2157,8 +2157,9 @@ class LSTM(PrimitiveWithInfer):
self.num_directions = 1
def infer_shape(self, x_shape, h_shape, c_shape, w_shape):
# (batch, seq, feature)
# (seq, batch_size, feature)
validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name)
validator.check_integer("x[2]", x_shape[2], self.input_size, Rel.EQ, self.name)
# h and c should be same shape
validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name)