diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/affine_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/affine_infer.c index c0e6e91c05c..07ad84871ae 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/affine_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/affine_infer.c @@ -77,7 +77,10 @@ int AffineInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * bool del_start = false; bool del_end = false; if (a_shape_size == 1) { - ShapeInsert(a_shape, &a_shape_size, 0, 1); + int ret = ShapeInsert(a_shape, &a_shape_size, 0, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } SetShapeArray(input0, a_shape, a_shape_size); del_start = true; } @@ -105,7 +108,10 @@ int AffineInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * } c_shape[c_shape_size - 1] = b_shape[b_shape_size - 1]; if (del_start) { - ShapeErase(c_shape, &c_shape_size, 0); + int erase_ret = ShapeErase(c_shape, &c_shape_size, 0); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } } if (del_end) { c_shape_size--; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/argmin_max_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/argmin_max_infer.c index dcba1a06d76..44cae261f29 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/argmin_max_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/argmin_max_infer.c @@ -62,7 +62,10 @@ int ArgMinMaxInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor return NNACL_PARAM_INVALID; } if (param->topk_ == 1 && !param->keep_dims_) { - ShapeErase(output_shape, &output_shape_size, axis); + int erase_ret = ShapeErase(output_shape, &output_shape_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } } else { output_shape[axis] = param->topk_; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.c index 5d111a6183c..7969d623685 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.c @@ -149,20 +149,18 @@ int CheckAugmentWithMinSize(const TensorC *const *inputs, size_t inputs_size, Te return NNACL_OK; } -int SetShapeTensor(TensorC *dst, const TensorC *src) { +void SetShapeTensor(TensorC *dst, const TensorC *src) { for (size_t i = 0; i < src->shape_size_; i++) { dst->shape_[i] = src->shape_[i]; } dst->shape_size_ = src->shape_size_; - return NNACL_OK; } -int SetShapeArray(TensorC *dst, const int *src, size_t src_size) { +void SetShapeArray(TensorC *dst, const int *src, size_t src_size) { for (size_t i = 0; i < src_size; i++) { dst->shape_[i] = src[i]; } dst->shape_size_ = src_size; - return NNACL_OK; } void SetDataTypeFormat(TensorC *dst, const TensorC *src) { @@ -287,18 +285,16 @@ int GetDimensionSize(const TensorC *tensor, const size_t index) { return dim_size; } -int ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size) { +void ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size) { for (size_t i = 0; i < src_shape_size; i++) { dst_shape[i] = src_shape[i]; } *dst_shape_size = src_shape_size; - return NNACL_OK; } -int ShapePush(int *shape, size_t *shape_size, int value) { +void ShapePush(int *shape, size_t *shape_size, int value) { shape[*shape_size] = value; *shape_size = *shape_size + 1; - return NNACL_OK; } int ShapeInsert(int *shape, size_t *shape_size, int index, int value) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.h index 301b097ed3d..3b2a9e197b2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.h @@ -181,10 +181,10 @@ int CheckAugmentWithMinSize(const TensorC *const *inputs, size_t inputs_size, Te const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj); void SetDataTypeFormat(TensorC *dst, const TensorC *src); -int SetShapeTensor(TensorC *dst, const TensorC *src); -int SetShapeArray(TensorC *dst, const int *src, size_t src_size); -int ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size); -int ShapePush(int *shape, size_t *shape_size, int value); +void SetShapeTensor(TensorC *dst, const TensorC *src); +void SetShapeArray(TensorC *dst, const int *src, size_t src_size); +void ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size); +void ShapePush(int *shape, size_t *shape_size, int value); int ShapeInsert(int *shape, size_t *shape_size, int index, int value); int ShapeErase(int *shape, size_t *shape_size, int index); bool ShapeEqual(const int *shape0, size_t shape0_size, const int *shape1, size_t shape1_size); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/concat_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/concat_infer.c index 04a32c701de..638e4a1a5fd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/concat_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/concat_infer.c @@ -48,7 +48,10 @@ int ConcatInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * int input0_shape_without_axis[MAX_SHAPE_SIZE] = {0}; size_t input0_shape_without_axis_size = 0; ShapeSet(input0_shape_without_axis, &input0_shape_without_axis_size, input0_shape, input0_shape_size); - ShapeErase(input0_shape_without_axis, &input0_shape_without_axis_size, axis); + int erase_ret = ShapeErase(input0_shape_without_axis, &input0_shape_without_axis_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } int output_axis_dim = input0_shape[axis]; for (size_t i = 1; i < inputs_size; ++i) { if (inputs[i]->shape_size_ != input0_shape_size) { @@ -63,7 +66,10 @@ int ConcatInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * return NNACL_PARAM_INVALID; } int axis_tmp = shape_tmp[axis]; - ShapeErase(shape_tmp, &shape_tmp_size, axis); + erase_ret = ShapeErase(shape_tmp, &shape_tmp_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } if (!ShapeEqual(input0_shape_without_axis, input0_shape_without_axis_size, shape_tmp, shape_tmp_size)) { return NNACL_ERR; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/embedding_lookup_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/embedding_lookup_infer.c index e9e9286a450..bcaecf4c583 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/embedding_lookup_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/embedding_lookup_infer.c @@ -41,7 +41,10 @@ int EmbeddingLookupInferShape(const TensorC *const *inputs, size_t inputs_size, int embedding_shape[MAX_SHAPE_SIZE] = {0}; size_t embedding_shape_size = 0; ShapeSet(embedding_shape, &embedding_shape_size, params_->shape_, params_->shape_size_); - ShapeErase(embedding_shape, &embedding_shape_size, 0); + int erase_ret = ShapeErase(embedding_shape, &embedding_shape_size, 0); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; ShapeSet(output_shape, &output_shape_size, ids->shape_, ids->shape_size_); @@ -55,7 +58,10 @@ int EmbeddingLookupInferShape(const TensorC *const *inputs, size_t inputs_size, int embedding_shape_t[MAX_SHAPE_SIZE] = {0}; size_t embedding_shape_t_size = 0; ShapeSet(embedding_shape_t, &embedding_shape_t_size, inputs[i]->shape_, inputs[i]->shape_size_); - ShapeErase(embedding_shape_t, &embedding_shape_t_size, 0); + erase_ret = ShapeErase(embedding_shape_t, &embedding_shape_t_size, 0); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } bool t_equal = ShapeEqual(embedding_shape_t, embedding_shape_t_size, embedding_shape, embedding_shape_size); if (!t_equal) { return NNACL_INPUT_TENSOR_ERROR; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/expand_dims_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/expand_dims_infer.c index 10dd49405e7..39ed749343b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/expand_dims_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/expand_dims_infer.c @@ -43,7 +43,10 @@ int ExpandDimsInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso } ShapeSet(output->shape_, &(output->shape_size_), input->shape_, input->shape_size_); - ShapeInsert(output->shape_, &(output->shape_size_), dim, 1); + int ret = ShapeInsert(output->shape_, &(output->shape_size_), dim, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } return NNACL_OK; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gather_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gather_infer.c index dc6b3cb98e7..b8ca877d4c2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gather_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gather_infer.c @@ -61,9 +61,15 @@ int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * int out_shape[MAX_SHAPE_SIZE] = {0}; size_t out_shape_size = 0; ShapeSet(out_shape, &out_shape_size, in_shape, in_shape_size); - ShapeErase(out_shape, &out_shape_size, axis); + int erase_ret = ShapeErase(out_shape, &out_shape_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } for (int i = indices_rank - 1; i >= 0; --i) { - ShapeInsert(out_shape, &out_shape_size, axis, indices_shape[i]); + ret = ShapeInsert(out_shape, &out_shape_size, axis, indices_shape[i]); + if (ret != NNACL_OK) { + return NNACL_ERR; + } } SetShapeArray(output, out_shape, out_shape_size); return NNACL_OK; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gru_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gru_infer.c index 355203725d5..cd0887089d5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gru_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gru_infer.c @@ -68,9 +68,15 @@ int GruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **ou GruParameter *param = (GruParameter *)parameter; if (param->bidirectional_) { - ShapeInsert(out_shape, &out_shape_size, 1, 2); + int ret = ShapeInsert(out_shape, &out_shape_size, 1, 2); + if (ret != NNACL_OK) { + return NNACL_ERR; + } } else { - ShapeInsert(out_shape, &out_shape_size, 1, 1); + int ret = ShapeInsert(out_shape, &out_shape_size, 1, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } } SetShapeArray(output, out_shape, out_shape_size); // set hidden state diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/lstm_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/lstm_infer.c index d2de069ef62..1bd18420439 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/lstm_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/lstm_infer.c @@ -46,9 +46,15 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o ShapeSet(out_shape, &out_shape_size, input->shape_, input->shape_size_); out_shape[2] = hidden_size; if (param->bidirectional_) { - ShapeInsert(out_shape, &out_shape_size, 1, 2); + int ret = ShapeInsert(out_shape, &out_shape_size, 1, 2); + if (ret != NNACL_OK) { + return NNACL_ERR; + } } else { - ShapeInsert(out_shape, &out_shape_size, 1, 1); + int ret = ShapeInsert(out_shape, &out_shape_size, 1, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } } SetShapeArray(output, out_shape, out_shape_size); int state_shape[MAX_SHAPE_SIZE]; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/matmul_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/matmul_infer.c index e80bd4e11d5..31f169c242d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/matmul_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/matmul_infer.c @@ -17,6 +17,31 @@ #include "nnacl/infer/matmul_infer.h" #include "nnacl/infer/infer_register.h" +int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_t b_shape_size, + MatMulParameter *param) { + for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) { + if (a_shape[i] != b_shape[i]) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + if (param->a_transpose_) { + if (a_shape_size < 2) { + return NNACL_ERR; + } + iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - 2]); + } + if (param->b_transpose_) { + if (b_shape_size < 2) { + return NNACL_ERR; + } + iswap(&b_shape[b_shape_size - 1], &b_shape[b_shape_size - 2]); + } + if (a_shape[a_shape_size - 1] != b_shape[b_shape_size - 2]) { + return NNACL_ERR; + } + return NNACL_OK; +} + int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter) { int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1); @@ -48,7 +73,10 @@ int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * bool del_start = false; bool del_end = false; if (a_shape_size == 1) { - ShapeInsert(a_shape, &a_shape_size, 0, 1); + int insert_ret = ShapeInsert(a_shape, &a_shape_size, 0, 1); + if (insert_ret != NNACL_OK) { + return NNACL_ERR; + } SetShapeArray(input0, a_shape, a_shape_size); del_start = true; } @@ -57,25 +85,8 @@ int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * SetShapeArray(input1, b_shape, b_shape_size); del_end = true; } - for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) { - if (a_shape[i] != b_shape[i]) { - return NNACL_INPUT_TENSOR_ERROR; - } - } - - if (param->a_transpose_) { - if (a_shape_size < 2) { - return NNACL_ERR; - } - iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - 2]); - } - if (param->b_transpose_) { - if (b_shape_size < 2) { - return NNACL_ERR; - } - iswap(&b_shape[b_shape_size - 1], &b_shape[b_shape_size - 2]); - } - if (a_shape[a_shape_size - 1] != b_shape[b_shape_size - 2]) { + int ret = CheckMatmulInputShape(a_shape, a_shape_size, b_shape, b_shape_size, param); + if (ret != NNACL_OK) { return NNACL_ERR; } int c_shape[MAX_SHAPE_SIZE]; @@ -86,7 +97,10 @@ int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * } c_shape[c_shape_size - 1] = b_shape[b_shape_size - 1]; if (del_start) { - ShapeErase(c_shape, &c_shape_size, 0); + int erase_ret = ShapeErase(c_shape, &c_shape_size, 0); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } } if (del_end) { c_shape_size--; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/stack_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/stack_infer.c index 45754af1de1..d533441390d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/stack_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/stack_infer.c @@ -59,7 +59,10 @@ int StackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC ** return NNACL_PARAM_INVALID; } } - ShapeInsert(output_shape, &output_shape_size, axis, inputs_size); + int insert_ret = ShapeInsert(output_shape, &output_shape_size, axis, inputs_size); + if (insert_ret != NNACL_OK) { + return NNACL_ERR; + } SetShapeArray(outputs[0], output_shape, output_shape_size); return NNACL_OK; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_grad_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_grad_infer.c index 8a21fc2237c..b4be741c3d4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_grad_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_grad_infer.c @@ -32,6 +32,25 @@ bool StridedSliceCheckInputs(const TensorC *const *inputs, size_t inputs_size) { return true; // note: the original code is ndim_ <= in_shape_size } +void ApplyBeginEndEllipsisMask(size_t ndim, int *begins, uint32_t *begins_mask, int *ends, uint32_t *ends_mask, + uint32_t *ellipsis_mask, int *in_shape) { + for (size_t i = 0; i < ndim; i++) { + if (begins_mask[i]) { + begins[i] = 0; + } + if (ends_mask[i]) { + ends[i] = in_shape[i]; + } + } + for (size_t i = 0; i < ndim; i++) { + if (ellipsis_mask[i]) { + begins[i] = 0; + ends[i] = in_shape[i]; + break; + } + } +} + int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter) { int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 1); @@ -97,7 +116,10 @@ int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, for (size_t i = 0; i < ndim_; i++) { if (new_axis_mask_[i]) { ndim_ += 1; - ShapeInsert(in_shape_, &in_shape_size, i, 1); + int ret = ShapeInsert(in_shape_, &in_shape_size, i, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } begins_[i] = 0; ends_[i] = 1; strides_[i] = 1; @@ -111,27 +133,7 @@ int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, ellipsis_mask_[i] = false; } } - // ApplyBeginMask; - for (size_t i = 0; i < ndim_; i++) { - if (begins_mask_[i]) { - begins_[i] = 0; - } - } - // ApplyEndMask; - for (size_t i = 0; i < ndim_; i++) { - if (ends_mask_[i]) { - ends_[i] = in_shape_[i]; - } - } - // ApplyEllipsisMask; - for (size_t i = 0; i < ndim_; i++) { - if (ellipsis_mask_[i]) { - begins_[i] = 0; - ends_[i] = in_shape_[i]; - break; - } - } - + ApplyBeginEndEllipsisMask(ndim_, begins_, begins_mask_, ends_, ends_mask_, ellipsis_mask_, in_shape_); if (!inferflag) { return NNACL_OK; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_infer.c index c8395d86697..442d95624d3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_infer.c @@ -206,7 +206,10 @@ int ApplyNewAxisMask(StridedSliceTransferBuffer *transfer_buffer, StridedSlicePa if (*out_shape_size >= MAX_SHAPE_SIZE) { return NNACL_ERR; } - ShapeInsert(in_shape, out_shape_size, i, 1); + int ret = ShapeInsert(in_shape, out_shape_size, i, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } transfer_buffer->begins_[i] = 0; transfer_buffer->ends_[i] = 1; transfer_buffer->strides_[i] = 1; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_fromtensor_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_fromtensor_infer.c index 9788be3c1b6..d2cf972edb9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_fromtensor_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_fromtensor_infer.c @@ -66,7 +66,10 @@ int TensorListFromTensorInferShape(const TensorC *const *inputs, size_t inputs_s ShapeSet(output->element_shape_, &(output->element_shape_size_), ele_shape_ptr, GetElementNum(input1)); output->element_num_ = dim0; - MallocTensorListData(output, input0->data_type_, &tensor_shape); + int ret = MallocTensorListData(output, input0->data_type_, &tensor_shape); + if (ret != NNACL_OK) { + return NNACL_ERR; + } free(tensor_shape.shape_); free(tensor_shape.shape_size_); return NNACL_OK; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_reserve_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_reserve_infer.c index d7b08654638..6827db30cb9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_reserve_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_reserve_infer.c @@ -73,7 +73,10 @@ int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size tmp_shape.shape_size_[i] = 0; tmp_shape.shape_[i] = NULL; } - MallocTensorListData(output, kTypeUnknown, &tmp_shape); + int ret = MallocTensorListData(output, kTypeUnknown, &tmp_shape); + if (ret != NNACL_OK) { + return NNACL_ERR; + } free(tmp_shape.shape_size_); free(tmp_shape.shape_); return NNACL_OK; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_setitem_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_setitem_infer.c index 7d9f5ff1e41..495f0609523 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_setitem_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_setitem_infer.c @@ -108,7 +108,10 @@ int TensorListSetItemInferShape(const TensorC *const *inputs, size_t inputs_size out_shape.shape_[index] = (int *)(value_tensor->shape_); out_shape.shape_size_[index] = value_tensor->shape_size_; - MallocTensorListData(output0, input0->tensors_data_type_, &out_shape); + int ret = MallocTensorListData(output0, input0->tensors_data_type_, &out_shape); + if (ret != NNACL_OK) { + return NNACL_ERR; + } free(out_shape.shape_); free(out_shape.shape_size_); return NNACL_OK; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_stack_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_stack_infer.c index 7059cdd9f52..07634be77b6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_stack_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_stack_infer.c @@ -83,7 +83,10 @@ int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size, if (output_shape_size >= MAX_SHAPE_SIZE) { return NNACL_ERR; } - ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_); + int ret = ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_); + if (ret != NNACL_OK) { + return NNACL_ERR; + } SetShapeArray(output, output_shape, output_shape_size); return NNACL_OK; }