forked from mindspore-Ecosystem/mindspore
dims check
This commit is contained in:
parent
bab6e0f549
commit
8eba2c5147
|
@ -2157,8 +2157,9 @@ class LSTM(PrimitiveWithInfer):
|
|||
self.num_directions = 1
|
||||
|
||||
def infer_shape(self, x_shape, h_shape, c_shape, w_shape):
|
||||
# (batch, seq, feature)
|
||||
# (seq, batch_size, feature)
|
||||
validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name)
|
||||
validator.check_integer("x[2]", x_shape[2], self.input_size, Rel.EQ, self.name)
|
||||
|
||||
# h and c should be same shape
|
||||
validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name)
|
||||
|
|
Loading…
Reference in New Issue