forked from mindspore-Ecosystem/mindspore
!28563 [MSLITE][DEVELOP] judge conv weight format
Merge pull request !28563 from yangruoqi713/master
This commit is contained in:
commit
0b668d10fb
|
@ -103,6 +103,9 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
|
|||
return NNACL_FORMAT_ERROR;
|
||||
}
|
||||
const TensorC *weight_tensor = inputs[1];
|
||||
if (weight_tensor->format_ != Format_NHWC && weight_tensor->format_ != Format_KHWC) {
|
||||
return NNACL_FORMAT_ERROR;
|
||||
}
|
||||
TensorC *out_tensor = outputs[0];
|
||||
if (out_tensor->format_ != Format_NC4HW4) {
|
||||
out_tensor->format_ = input_tensor->format_;
|
||||
|
@ -123,9 +126,9 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
|
|||
if (input_tensor->shape_size_ == 0) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
int input_h = in_shape[1];
|
||||
int input_w = in_shape[2];
|
||||
int input_c = in_shape[3];
|
||||
int input_h = in_shape[DIMENSION_1D];
|
||||
int input_w = in_shape[DIMENSION_2D];
|
||||
int input_c = in_shape[DIMENSION_3D];
|
||||
int output_w = 0, output_h = 0;
|
||||
|
||||
int ret = CheckConvAttr(input_c, weight_tensor, param);
|
||||
|
@ -141,19 +144,19 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
|
|||
int out_shape[MAX_SHAPE_SIZE];
|
||||
size_t out_shape_size = 0;
|
||||
ShapeSet(out_shape, &out_shape_size, input_tensor->shape_, input_tensor->shape_size_);
|
||||
out_shape[1] = output_h >= 0 ? output_h : 1;
|
||||
out_shape[2] = output_w >= 0 ? output_w : 1;
|
||||
out_shape[3] = GetBatch(weight_tensor);
|
||||
out_shape[DIMENSION_1D] = output_h >= 0 ? output_h : 1;
|
||||
out_shape[DIMENSION_2D] = output_w >= 0 ? output_w : 1;
|
||||
out_shape[DIMENSION_3D] = GetBatch(weight_tensor);
|
||||
SetShapeArray(out_tensor, out_shape, out_shape_size);
|
||||
|
||||
param->input_batch_ = in_shape[0];
|
||||
param->input_h_ = in_shape[1];
|
||||
param->input_w_ = in_shape[2];
|
||||
param->input_channel_ = in_shape[3];
|
||||
param->output_batch_ = out_shape[0];
|
||||
param->output_h_ = out_shape[1];
|
||||
param->output_w_ = out_shape[2];
|
||||
param->output_channel_ = out_shape[3];
|
||||
param->input_batch_ = in_shape[DIMENSION_0D];
|
||||
param->input_h_ = in_shape[DIMENSION_1D];
|
||||
param->input_w_ = in_shape[DIMENSION_2D];
|
||||
param->input_channel_ = in_shape[DIMENSION_3D];
|
||||
param->output_batch_ = out_shape[DIMENSION_0D];
|
||||
param->output_h_ = out_shape[DIMENSION_1D];
|
||||
param->output_w_ = out_shape[DIMENSION_2D];
|
||||
param->output_channel_ = out_shape[DIMENSION_3D];
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue