From dfeed7c02ec7caa2e7449d8092e16b6a6389e264 Mon Sep 17 00:00:00 2001 From: yangruoqi713 Date: Fri, 14 Jan 2022 15:21:56 +0800 Subject: [PATCH] [MSLITE][DEVELOP] judge conv weight format --- .../cpu/nnacl/infer/conv2d_infer.c | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/conv2d_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/conv2d_infer.c index ceae47e3276..9b189537efd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/conv2d_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/conv2d_infer.c @@ -103,6 +103,9 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * return NNACL_FORMAT_ERROR; } const TensorC *weight_tensor = inputs[1]; + if (weight_tensor->format_ != Format_NHWC && weight_tensor->format_ != Format_KHWC) { + return NNACL_FORMAT_ERROR; + } TensorC *out_tensor = outputs[0]; if (out_tensor->format_ != Format_NC4HW4) { out_tensor->format_ = input_tensor->format_; @@ -123,9 +126,9 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * if (input_tensor->shape_size_ == 0) { return NNACL_INFER_INVALID; } - int input_h = in_shape[1]; - int input_w = in_shape[2]; - int input_c = in_shape[3]; + int input_h = in_shape[DIMENSION_1D]; + int input_w = in_shape[DIMENSION_2D]; + int input_c = in_shape[DIMENSION_3D]; int output_w = 0, output_h = 0; int ret = CheckConvAttr(input_c, weight_tensor, param); @@ -141,19 +144,19 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * int out_shape[MAX_SHAPE_SIZE]; size_t out_shape_size = 0; ShapeSet(out_shape, &out_shape_size, input_tensor->shape_, input_tensor->shape_size_); - out_shape[1] = output_h >= 0 ? output_h : 1; - out_shape[2] = output_w >= 0 ? output_w : 1; - out_shape[3] = GetBatch(weight_tensor); + out_shape[DIMENSION_1D] = output_h >= 0 ? output_h : 1; + out_shape[DIMENSION_2D] = output_w >= 0 ? output_w : 1; + out_shape[DIMENSION_3D] = GetBatch(weight_tensor); SetShapeArray(out_tensor, out_shape, out_shape_size); - param->input_batch_ = in_shape[0]; - param->input_h_ = in_shape[1]; - param->input_w_ = in_shape[2]; - param->input_channel_ = in_shape[3]; - param->output_batch_ = out_shape[0]; - param->output_h_ = out_shape[1]; - param->output_w_ = out_shape[2]; - param->output_channel_ = out_shape[3]; + param->input_batch_ = in_shape[DIMENSION_0D]; + param->input_h_ = in_shape[DIMENSION_1D]; + param->input_w_ = in_shape[DIMENSION_2D]; + param->input_channel_ = in_shape[DIMENSION_3D]; + param->output_batch_ = out_shape[DIMENSION_0D]; + param->output_h_ = out_shape[DIMENSION_1D]; + param->output_w_ = out_shape[DIMENSION_2D]; + param->output_channel_ = out_shape[DIMENSION_3D]; return NNACL_OK; }