forked from mindspore-Ecosystem/mindspore
!28563 [MSLITE][DEVELOP] judge conv weight format
Merge pull request !28563 from yangruoqi713/master
This commit is contained in:
commit
0b668d10fb
|
@ -103,6 +103,9 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
|
||||||
return NNACL_FORMAT_ERROR;
|
return NNACL_FORMAT_ERROR;
|
||||||
}
|
}
|
||||||
const TensorC *weight_tensor = inputs[1];
|
const TensorC *weight_tensor = inputs[1];
|
||||||
|
if (weight_tensor->format_ != Format_NHWC && weight_tensor->format_ != Format_KHWC) {
|
||||||
|
return NNACL_FORMAT_ERROR;
|
||||||
|
}
|
||||||
TensorC *out_tensor = outputs[0];
|
TensorC *out_tensor = outputs[0];
|
||||||
if (out_tensor->format_ != Format_NC4HW4) {
|
if (out_tensor->format_ != Format_NC4HW4) {
|
||||||
out_tensor->format_ = input_tensor->format_;
|
out_tensor->format_ = input_tensor->format_;
|
||||||
|
@ -123,9 +126,9 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
|
||||||
if (input_tensor->shape_size_ == 0) {
|
if (input_tensor->shape_size_ == 0) {
|
||||||
return NNACL_INFER_INVALID;
|
return NNACL_INFER_INVALID;
|
||||||
}
|
}
|
||||||
int input_h = in_shape[1];
|
int input_h = in_shape[DIMENSION_1D];
|
||||||
int input_w = in_shape[2];
|
int input_w = in_shape[DIMENSION_2D];
|
||||||
int input_c = in_shape[3];
|
int input_c = in_shape[DIMENSION_3D];
|
||||||
int output_w = 0, output_h = 0;
|
int output_w = 0, output_h = 0;
|
||||||
|
|
||||||
int ret = CheckConvAttr(input_c, weight_tensor, param);
|
int ret = CheckConvAttr(input_c, weight_tensor, param);
|
||||||
|
@ -141,19 +144,19 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
|
||||||
int out_shape[MAX_SHAPE_SIZE];
|
int out_shape[MAX_SHAPE_SIZE];
|
||||||
size_t out_shape_size = 0;
|
size_t out_shape_size = 0;
|
||||||
ShapeSet(out_shape, &out_shape_size, input_tensor->shape_, input_tensor->shape_size_);
|
ShapeSet(out_shape, &out_shape_size, input_tensor->shape_, input_tensor->shape_size_);
|
||||||
out_shape[1] = output_h >= 0 ? output_h : 1;
|
out_shape[DIMENSION_1D] = output_h >= 0 ? output_h : 1;
|
||||||
out_shape[2] = output_w >= 0 ? output_w : 1;
|
out_shape[DIMENSION_2D] = output_w >= 0 ? output_w : 1;
|
||||||
out_shape[3] = GetBatch(weight_tensor);
|
out_shape[DIMENSION_3D] = GetBatch(weight_tensor);
|
||||||
SetShapeArray(out_tensor, out_shape, out_shape_size);
|
SetShapeArray(out_tensor, out_shape, out_shape_size);
|
||||||
|
|
||||||
param->input_batch_ = in_shape[0];
|
param->input_batch_ = in_shape[DIMENSION_0D];
|
||||||
param->input_h_ = in_shape[1];
|
param->input_h_ = in_shape[DIMENSION_1D];
|
||||||
param->input_w_ = in_shape[2];
|
param->input_w_ = in_shape[DIMENSION_2D];
|
||||||
param->input_channel_ = in_shape[3];
|
param->input_channel_ = in_shape[DIMENSION_3D];
|
||||||
param->output_batch_ = out_shape[0];
|
param->output_batch_ = out_shape[DIMENSION_0D];
|
||||||
param->output_h_ = out_shape[1];
|
param->output_h_ = out_shape[DIMENSION_1D];
|
||||||
param->output_w_ = out_shape[2];
|
param->output_w_ = out_shape[DIMENSION_2D];
|
||||||
param->output_channel_ = out_shape[3];
|
param->output_channel_ = out_shape[DIMENSION_3D];
|
||||||
|
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue