!21043 fix infer function

Merge pull request !21043 from zhaodezan/master
This commit is contained in:
i-robot 2021-07-30 03:24:24 +00:00 committed by Gitee
commit c6dc35f510
18 changed files with 147 additions and 75 deletions

View File

@ -77,7 +77,10 @@ int AffineInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
bool del_start = false;
bool del_end = false;
if (a_shape_size == 1) {
ShapeInsert(a_shape, &a_shape_size, 0, 1);
int ret = ShapeInsert(a_shape, &a_shape_size, 0, 1);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
SetShapeArray(input0, a_shape, a_shape_size);
del_start = true;
}
@ -105,7 +108,10 @@ int AffineInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
}
c_shape[c_shape_size - 1] = b_shape[b_shape_size - 1];
if (del_start) {
ShapeErase(c_shape, &c_shape_size, 0);
int erase_ret = ShapeErase(c_shape, &c_shape_size, 0);
if (erase_ret != NNACL_OK) {
return NNACL_ERR;
}
}
if (del_end) {
c_shape_size--;

View File

@ -62,7 +62,10 @@ int ArgMinMaxInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
return NNACL_PARAM_INVALID;
}
if (param->topk_ == 1 && !param->keep_dims_) {
ShapeErase(output_shape, &output_shape_size, axis);
int erase_ret = ShapeErase(output_shape, &output_shape_size, axis);
if (erase_ret != NNACL_OK) {
return NNACL_ERR;
}
} else {
output_shape[axis] = param->topk_;
}

View File

@ -149,20 +149,18 @@ int CheckAugmentWithMinSize(const TensorC *const *inputs, size_t inputs_size, Te
return NNACL_OK;
}
int SetShapeTensor(TensorC *dst, const TensorC *src) {
void SetShapeTensor(TensorC *dst, const TensorC *src) {
for (size_t i = 0; i < src->shape_size_; i++) {
dst->shape_[i] = src->shape_[i];
}
dst->shape_size_ = src->shape_size_;
return NNACL_OK;
}
int SetShapeArray(TensorC *dst, const int *src, size_t src_size) {
void SetShapeArray(TensorC *dst, const int *src, size_t src_size) {
for (size_t i = 0; i < src_size; i++) {
dst->shape_[i] = src[i];
}
dst->shape_size_ = src_size;
return NNACL_OK;
}
void SetDataTypeFormat(TensorC *dst, const TensorC *src) {
@ -287,18 +285,16 @@ int GetDimensionSize(const TensorC *tensor, const size_t index) {
return dim_size;
}
int ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size) {
void ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size) {
for (size_t i = 0; i < src_shape_size; i++) {
dst_shape[i] = src_shape[i];
}
*dst_shape_size = src_shape_size;
return NNACL_OK;
}
int ShapePush(int *shape, size_t *shape_size, int value) {
void ShapePush(int *shape, size_t *shape_size, int value) {
shape[*shape_size] = value;
*shape_size = *shape_size + 1;
return NNACL_OK;
}
int ShapeInsert(int *shape, size_t *shape_size, int index, int value) {

View File

@ -181,10 +181,10 @@ int CheckAugmentWithMinSize(const TensorC *const *inputs, size_t inputs_size, Te
const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj);
void SetDataTypeFormat(TensorC *dst, const TensorC *src);
int SetShapeTensor(TensorC *dst, const TensorC *src);
int SetShapeArray(TensorC *dst, const int *src, size_t src_size);
int ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size);
int ShapePush(int *shape, size_t *shape_size, int value);
void SetShapeTensor(TensorC *dst, const TensorC *src);
void SetShapeArray(TensorC *dst, const int *src, size_t src_size);
void ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size);
void ShapePush(int *shape, size_t *shape_size, int value);
int ShapeInsert(int *shape, size_t *shape_size, int index, int value);
int ShapeErase(int *shape, size_t *shape_size, int index);
bool ShapeEqual(const int *shape0, size_t shape0_size, const int *shape1, size_t shape1_size);

View File

@ -48,7 +48,10 @@ int ConcatInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
int input0_shape_without_axis[MAX_SHAPE_SIZE] = {0};
size_t input0_shape_without_axis_size = 0;
ShapeSet(input0_shape_without_axis, &input0_shape_without_axis_size, input0_shape, input0_shape_size);
ShapeErase(input0_shape_without_axis, &input0_shape_without_axis_size, axis);
int erase_ret = ShapeErase(input0_shape_without_axis, &input0_shape_without_axis_size, axis);
if (erase_ret != NNACL_OK) {
return NNACL_ERR;
}
int output_axis_dim = input0_shape[axis];
for (size_t i = 1; i < inputs_size; ++i) {
if (inputs[i]->shape_size_ != input0_shape_size) {
@ -63,7 +66,10 @@ int ConcatInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
return NNACL_PARAM_INVALID;
}
int axis_tmp = shape_tmp[axis];
ShapeErase(shape_tmp, &shape_tmp_size, axis);
erase_ret = ShapeErase(shape_tmp, &shape_tmp_size, axis);
if (erase_ret != NNACL_OK) {
return NNACL_ERR;
}
if (!ShapeEqual(input0_shape_without_axis, input0_shape_without_axis_size, shape_tmp, shape_tmp_size)) {
return NNACL_ERR;
}

View File

@ -41,7 +41,10 @@ int EmbeddingLookupInferShape(const TensorC *const *inputs, size_t inputs_size,
int embedding_shape[MAX_SHAPE_SIZE] = {0};
size_t embedding_shape_size = 0;
ShapeSet(embedding_shape, &embedding_shape_size, params_->shape_, params_->shape_size_);
ShapeErase(embedding_shape, &embedding_shape_size, 0);
int erase_ret = ShapeErase(embedding_shape, &embedding_shape_size, 0);
if (erase_ret != NNACL_OK) {
return NNACL_ERR;
}
int output_shape[MAX_SHAPE_SIZE] = {0};
size_t output_shape_size = 0;
ShapeSet(output_shape, &output_shape_size, ids->shape_, ids->shape_size_);
@ -55,7 +58,10 @@ int EmbeddingLookupInferShape(const TensorC *const *inputs, size_t inputs_size,
int embedding_shape_t[MAX_SHAPE_SIZE] = {0};
size_t embedding_shape_t_size = 0;
ShapeSet(embedding_shape_t, &embedding_shape_t_size, inputs[i]->shape_, inputs[i]->shape_size_);
ShapeErase(embedding_shape_t, &embedding_shape_t_size, 0);
erase_ret = ShapeErase(embedding_shape_t, &embedding_shape_t_size, 0);
if (erase_ret != NNACL_OK) {
return NNACL_ERR;
}
bool t_equal = ShapeEqual(embedding_shape_t, embedding_shape_t_size, embedding_shape, embedding_shape_size);
if (!t_equal) {
return NNACL_INPUT_TENSOR_ERROR;

View File

@ -43,7 +43,10 @@ int ExpandDimsInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso
}
ShapeSet(output->shape_, &(output->shape_size_), input->shape_, input->shape_size_);
ShapeInsert(output->shape_, &(output->shape_size_), dim, 1);
int ret = ShapeInsert(output->shape_, &(output->shape_size_), dim, 1);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
return NNACL_OK;
}

View File

@ -61,9 +61,15 @@ int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
int out_shape[MAX_SHAPE_SIZE] = {0};
size_t out_shape_size = 0;
ShapeSet(out_shape, &out_shape_size, in_shape, in_shape_size);
ShapeErase(out_shape, &out_shape_size, axis);
int erase_ret = ShapeErase(out_shape, &out_shape_size, axis);
if (erase_ret != NNACL_OK) {
return NNACL_ERR;
}
for (int i = indices_rank - 1; i >= 0; --i) {
ShapeInsert(out_shape, &out_shape_size, axis, indices_shape[i]);
ret = ShapeInsert(out_shape, &out_shape_size, axis, indices_shape[i]);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
}
SetShapeArray(output, out_shape, out_shape_size);
return NNACL_OK;

View File

@ -68,9 +68,15 @@ int GruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **ou
GruParameter *param = (GruParameter *)parameter;
if (param->bidirectional_) {
ShapeInsert(out_shape, &out_shape_size, 1, 2);
int ret = ShapeInsert(out_shape, &out_shape_size, 1, 2);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
} else {
ShapeInsert(out_shape, &out_shape_size, 1, 1);
int ret = ShapeInsert(out_shape, &out_shape_size, 1, 1);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
}
SetShapeArray(output, out_shape, out_shape_size);
// set hidden state

View File

@ -46,9 +46,15 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
ShapeSet(out_shape, &out_shape_size, input->shape_, input->shape_size_);
out_shape[2] = hidden_size;
if (param->bidirectional_) {
ShapeInsert(out_shape, &out_shape_size, 1, 2);
int ret = ShapeInsert(out_shape, &out_shape_size, 1, 2);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
} else {
ShapeInsert(out_shape, &out_shape_size, 1, 1);
int ret = ShapeInsert(out_shape, &out_shape_size, 1, 1);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
}
SetShapeArray(output, out_shape, out_shape_size);
int state_shape[MAX_SHAPE_SIZE];

View File

@ -17,6 +17,31 @@
#include "nnacl/infer/matmul_infer.h"
#include "nnacl/infer/infer_register.h"
int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_t b_shape_size,
MatMulParameter *param) {
for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) {
if (a_shape[i] != b_shape[i]) {
return NNACL_INPUT_TENSOR_ERROR;
}
}
if (param->a_transpose_) {
if (a_shape_size < 2) {
return NNACL_ERR;
}
iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - 2]);
}
if (param->b_transpose_) {
if (b_shape_size < 2) {
return NNACL_ERR;
}
iswap(&b_shape[b_shape_size - 1], &b_shape[b_shape_size - 2]);
}
if (a_shape[a_shape_size - 1] != b_shape[b_shape_size - 2]) {
return NNACL_ERR;
}
return NNACL_OK;
}
int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1);
@ -48,7 +73,10 @@ int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
bool del_start = false;
bool del_end = false;
if (a_shape_size == 1) {
ShapeInsert(a_shape, &a_shape_size, 0, 1);
int insert_ret = ShapeInsert(a_shape, &a_shape_size, 0, 1);
if (insert_ret != NNACL_OK) {
return NNACL_ERR;
}
SetShapeArray(input0, a_shape, a_shape_size);
del_start = true;
}
@ -57,25 +85,8 @@ int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
SetShapeArray(input1, b_shape, b_shape_size);
del_end = true;
}
for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) {
if (a_shape[i] != b_shape[i]) {
return NNACL_INPUT_TENSOR_ERROR;
}
}
if (param->a_transpose_) {
if (a_shape_size < 2) {
return NNACL_ERR;
}
iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - 2]);
}
if (param->b_transpose_) {
if (b_shape_size < 2) {
return NNACL_ERR;
}
iswap(&b_shape[b_shape_size - 1], &b_shape[b_shape_size - 2]);
}
if (a_shape[a_shape_size - 1] != b_shape[b_shape_size - 2]) {
int ret = CheckMatmulInputShape(a_shape, a_shape_size, b_shape, b_shape_size, param);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
int c_shape[MAX_SHAPE_SIZE];
@ -86,7 +97,10 @@ int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
}
c_shape[c_shape_size - 1] = b_shape[b_shape_size - 1];
if (del_start) {
ShapeErase(c_shape, &c_shape_size, 0);
int erase_ret = ShapeErase(c_shape, &c_shape_size, 0);
if (erase_ret != NNACL_OK) {
return NNACL_ERR;
}
}
if (del_end) {
c_shape_size--;

View File

@ -59,7 +59,10 @@ int StackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **
return NNACL_PARAM_INVALID;
}
}
ShapeInsert(output_shape, &output_shape_size, axis, inputs_size);
int insert_ret = ShapeInsert(output_shape, &output_shape_size, axis, inputs_size);
if (insert_ret != NNACL_OK) {
return NNACL_ERR;
}
SetShapeArray(outputs[0], output_shape, output_shape_size);
return NNACL_OK;
}

View File

@ -32,6 +32,25 @@ bool StridedSliceCheckInputs(const TensorC *const *inputs, size_t inputs_size) {
return true; // note: the original code is ndim_ <= in_shape_size
}
void ApplyBeginEndEllipsisMask(size_t ndim, int *begins, uint32_t *begins_mask, int *ends, uint32_t *ends_mask,
uint32_t *ellipsis_mask, int *in_shape) {
for (size_t i = 0; i < ndim; i++) {
if (begins_mask[i]) {
begins[i] = 0;
}
if (ends_mask[i]) {
ends[i] = in_shape[i];
}
}
for (size_t i = 0; i < ndim; i++) {
if (ellipsis_mask[i]) {
begins[i] = 0;
ends[i] = in_shape[i];
break;
}
}
}
int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 1);
@ -97,7 +116,10 @@ int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size,
for (size_t i = 0; i < ndim_; i++) {
if (new_axis_mask_[i]) {
ndim_ += 1;
ShapeInsert(in_shape_, &in_shape_size, i, 1);
int ret = ShapeInsert(in_shape_, &in_shape_size, i, 1);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
begins_[i] = 0;
ends_[i] = 1;
strides_[i] = 1;
@ -111,27 +133,7 @@ int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size,
ellipsis_mask_[i] = false;
}
}
// ApplyBeginMask;
for (size_t i = 0; i < ndim_; i++) {
if (begins_mask_[i]) {
begins_[i] = 0;
}
}
// ApplyEndMask;
for (size_t i = 0; i < ndim_; i++) {
if (ends_mask_[i]) {
ends_[i] = in_shape_[i];
}
}
// ApplyEllipsisMask;
for (size_t i = 0; i < ndim_; i++) {
if (ellipsis_mask_[i]) {
begins_[i] = 0;
ends_[i] = in_shape_[i];
break;
}
}
ApplyBeginEndEllipsisMask(ndim_, begins_, begins_mask_, ends_, ends_mask_, ellipsis_mask_, in_shape_);
if (!inferflag) {
return NNACL_OK;
}

View File

@ -206,7 +206,10 @@ int ApplyNewAxisMask(StridedSliceTransferBuffer *transfer_buffer, StridedSlicePa
if (*out_shape_size >= MAX_SHAPE_SIZE) {
return NNACL_ERR;
}
ShapeInsert(in_shape, out_shape_size, i, 1);
int ret = ShapeInsert(in_shape, out_shape_size, i, 1);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
transfer_buffer->begins_[i] = 0;
transfer_buffer->ends_[i] = 1;
transfer_buffer->strides_[i] = 1;

View File

@ -66,7 +66,10 @@ int TensorListFromTensorInferShape(const TensorC *const *inputs, size_t inputs_s
ShapeSet(output->element_shape_, &(output->element_shape_size_), ele_shape_ptr, GetElementNum(input1));
output->element_num_ = dim0;
MallocTensorListData(output, input0->data_type_, &tensor_shape);
int ret = MallocTensorListData(output, input0->data_type_, &tensor_shape);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
free(tensor_shape.shape_);
free(tensor_shape.shape_size_);
return NNACL_OK;

View File

@ -73,7 +73,10 @@ int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size
tmp_shape.shape_size_[i] = 0;
tmp_shape.shape_[i] = NULL;
}
MallocTensorListData(output, kTypeUnknown, &tmp_shape);
int ret = MallocTensorListData(output, kTypeUnknown, &tmp_shape);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
free(tmp_shape.shape_size_);
free(tmp_shape.shape_);
return NNACL_OK;

View File

@ -108,7 +108,10 @@ int TensorListSetItemInferShape(const TensorC *const *inputs, size_t inputs_size
out_shape.shape_[index] = (int *)(value_tensor->shape_);
out_shape.shape_size_[index] = value_tensor->shape_size_;
MallocTensorListData(output0, input0->tensors_data_type_, &out_shape);
int ret = MallocTensorListData(output0, input0->tensors_data_type_, &out_shape);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
free(out_shape.shape_);
free(out_shape.shape_size_);
return NNACL_OK;

View File

@ -83,7 +83,10 @@ int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size,
if (output_shape_size >= MAX_SHAPE_SIZE) {
return NNACL_ERR;
}
ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_);
int ret = ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
SetShapeArray(output, output_shape, output_shape_size);
return NNACL_OK;
}