forked from mindspore-Ecosystem/mindspore
commit
82485ca4d8
|
@ -108,7 +108,7 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
|
|||
|
||||
if (outputs_size > DIMENSION_4D) {
|
||||
int intermediate_states_shape[MAX_SHAPE_SIZE];
|
||||
size_t intermediate_states_shape_size = 1;
|
||||
const size_t intermediate_states_shape_size = 1;
|
||||
int batch_size = input->shape_[SECOND_INPUT];
|
||||
int seq_len = input->shape_[FIRST_INPUT];
|
||||
intermediate_states_shape[FIRST_INPUT] = no_of_recorde_values * batch_size * hidden_size * seq_len * dir_multiplier;
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
|
||||
#define MIN_SHAPE_SIZE 2
|
||||
|
||||
int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_t b_shape_size, int *bias_shape,
|
||||
int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_t b_shape_size, const int *bias_shape,
|
||||
size_t bias_shape_size, const MatMulParameter *param) {
|
||||
if (a_shape_size < MIN_SHAPE_SIZE || b_shape_size < MIN_SHAPE_SIZE) {
|
||||
return NNACL_PARAM_INVALID;
|
||||
|
@ -36,13 +36,13 @@ int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_
|
|||
}
|
||||
}
|
||||
if (param->a_transpose_) {
|
||||
if (a_shape_size < 2) {
|
||||
if (a_shape_size < MIN_SHAPE_SIZE) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - 2]);
|
||||
iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - DIMENSION_2D]);
|
||||
}
|
||||
if (param->b_transpose_) {
|
||||
if (b_shape_size < 2) {
|
||||
if (b_shape_size < MIN_SHAPE_SIZE) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
iswap(&b_shape[b_shape_size - 1], &b_shape[b_shape_size - 2]);
|
||||
|
@ -76,7 +76,7 @@ int SetShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs
|
|||
MS_CHECK_TRUE_RET(bias_shape_size == b_shape_size || bias_shape_size == DIMENSION_1D, NNACL_ERR);
|
||||
}
|
||||
|
||||
if (a_shape_size == 4 && a_shape[2] == 1 && a_shape[3] == 1) {
|
||||
if (a_shape_size == COMM_SHAPE_SIZE && a_shape[THIRD_INPUT] == 1 && a_shape[FOURTH_INPUT] == 1) {
|
||||
a_shape_size = 2;
|
||||
SetShapeArray(input0, a_shape, a_shape_size);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue