!28822 fix codex infer

Merge pull request !28822 from zhaodezan/master
This commit is contained in:
i-robot 2022-01-17 07:02:32 +00:00 committed by Gitee
commit 82485ca4d8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 6 additions and 6 deletions

View File

@ -108,7 +108,7 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
if (outputs_size > DIMENSION_4D) {
int intermediate_states_shape[MAX_SHAPE_SIZE];
size_t intermediate_states_shape_size = 1;
const size_t intermediate_states_shape_size = 1;
int batch_size = input->shape_[SECOND_INPUT];
int seq_len = input->shape_[FIRST_INPUT];
intermediate_states_shape[FIRST_INPUT] = no_of_recorde_values * batch_size * hidden_size * seq_len * dir_multiplier;

View File

@ -20,7 +20,7 @@
#define MIN_SHAPE_SIZE 2
int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_t b_shape_size, int *bias_shape,
int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_t b_shape_size, const int *bias_shape,
size_t bias_shape_size, const MatMulParameter *param) {
if (a_shape_size < MIN_SHAPE_SIZE || b_shape_size < MIN_SHAPE_SIZE) {
return NNACL_PARAM_INVALID;
@ -36,13 +36,13 @@ int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_
}
}
if (param->a_transpose_) {
if (a_shape_size < 2) {
if (a_shape_size < MIN_SHAPE_SIZE) {
return NNACL_ERR;
}
iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - 2]);
iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - DIMENSION_2D]);
}
if (param->b_transpose_) {
if (b_shape_size < 2) {
if (b_shape_size < MIN_SHAPE_SIZE) {
return NNACL_ERR;
}
iswap(&b_shape[b_shape_size - 1], &b_shape[b_shape_size - 2]);
@ -76,7 +76,7 @@ int SetShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs
MS_CHECK_TRUE_RET(bias_shape_size == b_shape_size || bias_shape_size == DIMENSION_1D, NNACL_ERR);
}
if (a_shape_size == 4 && a_shape[2] == 1 && a_shape[3] == 1) {
if (a_shape_size == COMM_SHAPE_SIZE && a_shape[THIRD_INPUT] == 1 && a_shape[FOURTH_INPUT] == 1) {
a_shape_size = 2;
SetShapeArray(input0, a_shape, a_shape_size);
}