!28563 [MSLITE][DEVELOP] judge conv weight format

Merge pull request !28563 from yangruoqi713/master
This commit is contained in:
i-robot 2022-01-18 01:25:37 +00:00 committed by Gitee
commit 0b668d10fb
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 17 additions and 14 deletions

View File

@ -103,6 +103,9 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
return NNACL_FORMAT_ERROR;
}
const TensorC *weight_tensor = inputs[1];
if (weight_tensor->format_ != Format_NHWC && weight_tensor->format_ != Format_KHWC) {
return NNACL_FORMAT_ERROR;
}
TensorC *out_tensor = outputs[0];
if (out_tensor->format_ != Format_NC4HW4) {
out_tensor->format_ = input_tensor->format_;
@ -123,9 +126,9 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
if (input_tensor->shape_size_ == 0) {
return NNACL_INFER_INVALID;
}
int input_h = in_shape[1];
int input_w = in_shape[2];
int input_c = in_shape[3];
int input_h = in_shape[DIMENSION_1D];
int input_w = in_shape[DIMENSION_2D];
int input_c = in_shape[DIMENSION_3D];
int output_w = 0, output_h = 0;
int ret = CheckConvAttr(input_c, weight_tensor, param);
@ -141,19 +144,19 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
int out_shape[MAX_SHAPE_SIZE];
size_t out_shape_size = 0;
ShapeSet(out_shape, &out_shape_size, input_tensor->shape_, input_tensor->shape_size_);
out_shape[1] = output_h >= 0 ? output_h : 1;
out_shape[2] = output_w >= 0 ? output_w : 1;
out_shape[3] = GetBatch(weight_tensor);
out_shape[DIMENSION_1D] = output_h >= 0 ? output_h : 1;
out_shape[DIMENSION_2D] = output_w >= 0 ? output_w : 1;
out_shape[DIMENSION_3D] = GetBatch(weight_tensor);
SetShapeArray(out_tensor, out_shape, out_shape_size);
param->input_batch_ = in_shape[0];
param->input_h_ = in_shape[1];
param->input_w_ = in_shape[2];
param->input_channel_ = in_shape[3];
param->output_batch_ = out_shape[0];
param->output_h_ = out_shape[1];
param->output_w_ = out_shape[2];
param->output_channel_ = out_shape[3];
param->input_batch_ = in_shape[DIMENSION_0D];
param->input_h_ = in_shape[DIMENSION_1D];
param->input_w_ = in_shape[DIMENSION_2D];
param->input_channel_ = in_shape[DIMENSION_3D];
param->output_batch_ = out_shape[DIMENSION_0D];
param->output_h_ = out_shape[DIMENSION_1D];
param->output_w_ = out_shape[DIMENSION_2D];
param->output_channel_ = out_shape[DIMENSION_3D];
return NNACL_OK;
}