!21002 [MS][LITE][STABLE]support NCHW input tensor

Merge pull request !21002 from chenjianping/kernel_reg
This commit is contained in:
i-robot 2021-08-02 01:29:47 +00:00 committed by Gitee
commit 799772455b
21 changed files with 67 additions and 11 deletions

View File

@ -24,6 +24,7 @@ typedef enum ErrorCodeCommonEnum {
NNACL_PARAM_INVALID, NNACL_PARAM_INVALID,
NNACL_INFER_INVALID, NNACL_INFER_INVALID,
NNACL_INPUT_TENSOR_ERROR, NNACL_INPUT_TENSOR_ERROR,
NNACL_FORMAT_ERROR,
NNACL_COMMON_END = 9999 NNACL_COMMON_END = 9999
} ErrorCodeCommonEnum; } ErrorCodeCommonEnum;

View File

@ -122,7 +122,7 @@ int BatchToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, Ten
const TensorC *input = inputs[0]; const TensorC *input = inputs[0];
if (input->format_ != Format_NHWC) { if (input->format_ != Format_NHWC) {
return NNACL_ERR; return NNACL_FORMAT_ERROR;
} }
SetDataTypeFormat(outputs[0], input); SetDataTypeFormat(outputs[0], input);
if (!InferFlag(inputs, inputs_size)) { if (!InferFlag(inputs, inputs_size)) {

View File

@ -25,6 +25,9 @@ int BnGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
} }
const TensorC *in = inputs[1]; const TensorC *in = inputs[1];
if (inputs[0]->format_ != Format_NHWC || in->format_ != Format_NHWC) {
return NNACL_FORMAT_ERROR;
}
const TensorC *scale = inputs[2]; const TensorC *scale = inputs[2];
if (in->shape_size_ != 4) { if (in->shape_size_ != 4) {
return NNACL_INPUT_TENSOR_ERROR; return NNACL_INPUT_TENSOR_ERROR;

View File

@ -358,6 +358,22 @@ int CommonInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
return NNACL_OK; return NNACL_OK;
} }
int CommonInferShapeWithNHWC(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) {
return NNACL_NULL_PTR;
}
if (inputs[0]->format_ != Format_NHWC) {
return NNACL_FORMAT_ERROR;
}
SetDataTypeFormat(outputs[0], inputs[0]);
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
SetShapeTensor(outputs[0], inputs[0]);
return NNACL_OK;
}
int FftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, int FftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
const OpParameter *parameter) { const OpParameter *parameter) {
int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
@ -523,7 +539,7 @@ REG_INFER(LeakyRelu, PrimType_LeakyRelu, CommonInferShape)
REG_INFER(Log, PrimType_Log, CommonInferShape) REG_INFER(Log, PrimType_Log, CommonInferShape)
REG_INFER(LogGrad, PrimType_LogGrad, CommonInferShape) REG_INFER(LogGrad, PrimType_LogGrad, CommonInferShape)
REG_INFER(LogicalNot, PrimType_LogicalNot, CommonInferShape) REG_INFER(LogicalNot, PrimType_LogicalNot, CommonInferShape)
REG_INFER(LRN, PrimType_LRN, CommonInferShape) REG_INFER(LRN, PrimType_LRN, CommonInferShapeWithNHWC)
REG_INFER(L2Normalize, PrimType_L2NormalizeFusion, CommonInferShape) REG_INFER(L2Normalize, PrimType_L2NormalizeFusion, CommonInferShape)
REG_INFER(Neg, PrimType_Neg, CommonInferShape) REG_INFER(Neg, PrimType_Neg, CommonInferShape)
REG_INFER(NegGrad, PrimType_NegGrad, CommonInferShape) REG_INFER(NegGrad, PrimType_NegGrad, CommonInferShape)

View File

@ -47,8 +47,11 @@ typedef enum FormatC {
Format_NC4 = 12, Format_NC4 = 12,
Format_NC4HW4 = 13, Format_NC4HW4 = 13,
Format_NUM_OF_FORMAT = 14, Format_NUM_OF_FORMAT = 14,
Format_NCDHW = 15,
Format_NWC = 16,
Format_NCW = 17,
Format_MIN = Format_NCHW, Format_MIN = Format_NCHW,
Format_MAX = Format_NUM_OF_FORMAT Format_MAX = Format_NCW
} FormatC; } FormatC;
typedef enum TypeIdC { typedef enum TypeIdC {

View File

@ -26,6 +26,9 @@ int Conv2dGradFilterInferShape(const TensorC *const *inputs, size_t inputs_size,
if (inputs_size < 3 || outputs_size != 1) { if (inputs_size < 3 || outputs_size != 1) {
return NNACL_ERR; return NNACL_ERR;
} }
if (inputs[0]->format_ != Format_NHWC || inputs[1]->format_ != Format_NHWC) {
return NNACL_FORMAT_ERROR;
}
SetDataTypeFormat(outputs[0], inputs[0]); SetDataTypeFormat(outputs[0], inputs[0]);
if (inputs[2]->shape_size_ < 1 || inputs[2]->data_ == NULL) { if (inputs[2]->shape_size_ < 1 || inputs[2]->data_ == NULL) {

View File

@ -32,6 +32,9 @@ int Conv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size,
if (in0 == NULL || out == NULL) { if (in0 == NULL || out == NULL) {
return NNACL_NULL_PTR; return NNACL_NULL_PTR;
} }
if (in0->format_ != Format_NHWC) {
return NNACL_FORMAT_ERROR;
}
SetDataTypeFormat(out, in0); SetDataTypeFormat(out, in0);
if (inputs[2]->shape_size_ < 1 || inputs[2]->data_ == NULL) { if (inputs[2]->shape_size_ < 1 || inputs[2]->data_ == NULL) {

View File

@ -62,6 +62,9 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
} }
const TensorC *input_tensor = inputs[0]; const TensorC *input_tensor = inputs[0];
if (input_tensor->format_ != Format_NHWC && input_tensor->format_ != Format_KHWC) {
return NNACL_FORMAT_ERROR;
}
const TensorC *weight_tensor = inputs[1]; const TensorC *weight_tensor = inputs[1];
TensorC *out_tensor = outputs[0]; TensorC *out_tensor = outputs[0];

View File

@ -25,6 +25,9 @@ int Deconv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC
} }
const TensorC *input = inputs[0]; const TensorC *input = inputs[0];
if (input->format_ != Format_NHWC) {
return NNACL_FORMAT_ERROR;
}
const TensorC *weight = inputs[1]; const TensorC *weight = inputs[1];
TensorC *output = outputs[0]; TensorC *output = outputs[0];
output->format_ = input->format_; output->format_ = input->format_;

View File

@ -26,7 +26,7 @@ int DepthToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, Ten
const TensorC *input = inputs[0]; const TensorC *input = inputs[0];
if (input->format_ != Format_NHWC) { if (input->format_ != Format_NHWC) {
return NNACL_ERR; return NNACL_FORMAT_ERROR;
} }
SetDataTypeFormat(outputs[0], input); SetDataTypeFormat(outputs[0], input);
DepthToSpaceParameter *param = (DepthToSpaceParameter *)parameter; DepthToSpaceParameter *param = (DepthToSpaceParameter *)parameter;

View File

@ -26,6 +26,9 @@ int PoolingGradInferShape(const TensorC *const *inputs, size_t inputs_size, Tens
} }
const TensorC *input = inputs[0]; const TensorC *input = inputs[0];
if (input->format_ != Format_NHWC) {
return NNACL_FORMAT_ERROR;
}
int input_h = input->shape_[1]; int input_h = input->shape_[1];
int input_w = input->shape_[2]; int input_w = input->shape_[2];
if (input->shape_size_ != 4) { if (input->shape_size_ != 4) {

View File

@ -26,6 +26,9 @@ int PoolingInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC
} }
const TensorC *input = inputs[0]; const TensorC *input = inputs[0];
if (input->format_ != Format_NHWC) {
return NNACL_FORMAT_ERROR;
}
TensorC *output = outputs[0]; TensorC *output = outputs[0];
SetDataTypeFormat(output, input); SetDataTypeFormat(output, input);
PoolingParameter *param = (PoolingParameter *)parameter; PoolingParameter *param = (PoolingParameter *)parameter;

View File

@ -25,6 +25,9 @@ int ResizeGradInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso
} }
const TensorC *input = inputs[0]; const TensorC *input = inputs[0];
if (input->format_ != Format_NHWC) {
return NNACL_FORMAT_ERROR;
}
if (input->shape_size_ != 4) { if (input->shape_size_ != 4) {
return NNACL_ERR; return NNACL_ERR;
} }

View File

@ -123,6 +123,9 @@ int ResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
} }
const TensorC *input = inputs[0]; const TensorC *input = inputs[0];
if (input->format_ != Format_NHWC) {
return NNACL_FORMAT_ERROR;
}
TensorC *output = outputs[0]; TensorC *output = outputs[0];
SetDataTypeFormat(output, input); SetDataTypeFormat(output, input);
if (!InferFlag(inputs, inputs_size)) { if (!InferFlag(inputs, inputs_size)) {

View File

@ -28,6 +28,9 @@ int ROIPoolingInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso
} }
const TensorC *input = inputs[0]; const TensorC *input = inputs[0];
if (input->format_ != Format_NHWC) {
return NNACL_FORMAT_ERROR;
}
const TensorC *roi = inputs[1]; const TensorC *roi = inputs[1];
TensorC *output = outputs[0]; TensorC *output = outputs[0];
SetDataTypeFormat(output, input); SetDataTypeFormat(output, input);

View File

@ -26,7 +26,7 @@ int SpaceToBatchInferShape(const TensorC *const *inputs, size_t inputs_size, Ten
const TensorC *input = inputs[0]; const TensorC *input = inputs[0];
if (input->format_ != Format_NHWC) { if (input->format_ != Format_NHWC) {
return NNACL_ERR; return NNACL_FORMAT_ERROR;
} }
SetDataTypeFormat(outputs[0], input); SetDataTypeFormat(outputs[0], input);
SpaceToBatchParameter *param = (SpaceToBatchParameter *)parameter; SpaceToBatchParameter *param = (SpaceToBatchParameter *)parameter;

View File

@ -21,6 +21,9 @@
int SpaceSetOutputShapeFromParam(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, int SpaceSetOutputShapeFromParam(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
size_t outputs_size, OpParameter *parameter) { size_t outputs_size, OpParameter *parameter) {
const TensorC *input = inputs[0]; const TensorC *input = inputs[0];
if (input->format_ != Format_NHWC) {
return NNACL_FORMAT_ERROR;
}
if (input->shape_size_ != 4) { if (input->shape_size_ != 4) {
return NNACL_ERR; return NNACL_ERR;
} }

View File

@ -27,7 +27,7 @@ int SpaceToDepthInferShape(const TensorC *const *inputs, size_t inputs_size, Ten
const TensorC *input = inputs[0]; const TensorC *input = inputs[0];
if (input->format_ != Format_NHWC) { if (input->format_ != Format_NHWC) {
return NNACL_ERR; return NNACL_FORMAT_ERROR;
} }
SetDataTypeFormat(outputs[0], input); SetDataTypeFormat(outputs[0], input);
SpaceToDepthParameter *param = (SpaceToDepthParameter *)parameter; SpaceToDepthParameter *param = (SpaceToDepthParameter *)parameter;

View File

@ -278,10 +278,6 @@ bool LiteModel::ModelVerify() const {
MS_LOG(ERROR) << "Tensor in all tensors is nullptr."; MS_LOG(ERROR) << "Tensor in all tensors is nullptr.";
return false; return false;
} }
if (tensor->format() != schema::Format_NHWC) {
MS_LOG(ERROR) << "Graph input tensor should be NHWC";
return false;
}
} }
if (std::any_of(this->output_indices_.begin(), this->output_indices_.end(), if (std::any_of(this->output_indices_.begin(), this->output_indices_.end(),

View File

@ -67,6 +67,10 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vecto
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *parameter) { OpParameter *parameter) {
if (inputs.empty()) {
MS_LOG(ERROR) << "No input!";
return RET_ERROR;
}
std::vector<TensorC *> in_tensors; std::vector<TensorC *> in_tensors;
std::vector<TensorC *> out_tensors; std::vector<TensorC *> out_tensors;
if (parameter->type_ == schema::PrimitiveType_PartialFusion || parameter->type_ == schema::PrimitiveType_Switch || if (parameter->type_ == schema::PrimitiveType_PartialFusion || parameter->type_ == schema::PrimitiveType_Switch ||
@ -119,6 +123,9 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vecto
if (ret == NNACL_INFER_INVALID) { if (ret == NNACL_INFER_INVALID) {
return RET_INFER_INVALID; return RET_INFER_INVALID;
} else if (ret != NNACL_OK) { } else if (ret != NNACL_OK) {
if (ret == NNACL_FORMAT_ERROR) {
MS_LOG(ERROR) << "Unexpected input format " << inputs[0]->format();
}
return RET_INFER_ERR; return RET_INFER_ERR;
} }
return RET_OK; return RET_OK;

View File

@ -193,7 +193,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
auto ret = KernelInferShape(inputs, *outputs, parameter); auto ret = KernelInferShape(inputs, *outputs, parameter);
if (ret != lite::RET_OK) { if (ret != lite::RET_OK) {
free(parameter); free(parameter);
MS_LOG(ERROR) << "infershape failed."; MS_LOG(ERROR) << "infershape failed!type: " << schema::EnumNamePrimitiveType(prim->value_type());
return nullptr; return nullptr;
} }
auto data_type = inputs.front()->data_type(); auto data_type = inputs.front()->data_type();