!4210 [MS][LITE] fix bug of arm cpu fp32 conv_depthwise: only support group equals output channel

Merge pull request !4210 from yangruoqi713/lite
This commit is contained in:
mindspore-ci-bot 2020-08-10 19:00:08 +08:00 committed by Gitee
commit 13a66805b3
5 changed files with 13 additions and 3 deletions

View File

@ -40,6 +40,7 @@ int DepthwiseConv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
auto in_shape = input->shape(); auto in_shape = input->shape();
int input_h = in_shape.at(1); int input_h = in_shape.at(1);
int input_w = in_shape.at(2); int input_w = in_shape.at(2);
int input_channel = in_shape.at(3);
int output_w = 0, output_h = 0; int output_w = 0, output_h = 0;
auto conv_prim = this->primitive->value_as_DepthwiseConv2D(); auto conv_prim = this->primitive->value_as_DepthwiseConv2D();
@ -69,6 +70,10 @@ int DepthwiseConv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
std::vector<int> out_shape{input->shape()}; std::vector<int> out_shape{input->shape()};
out_shape.at(1) = output_h; out_shape.at(1) = output_h;
out_shape.at(2) = output_w; out_shape.at(2) = output_w;
if (conv_prim->channelMultiplier() * input_channel != weight->shape()[0]) {
MS_LOG(ERROR) << "Conv depthwise only support group equals output channel.";
return RET_ERROR;
}
out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel
output->set_shape(out_shape); output->set_shape(out_shape);

View File

@ -40,6 +40,7 @@ int DeconvDepthwiseConv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std
auto in_shape = input->shape(); auto in_shape = input->shape();
int input_h = in_shape.at(1); int input_h = in_shape.at(1);
int input_w = in_shape.at(2); int input_w = in_shape.at(2);
int input_channel = in_shape.at(3);
int output_w = 0, output_h = 0; int output_w = 0, output_h = 0;
auto conv_prim = this->primitive->value_as_DeDepthwiseConv2D(); auto conv_prim = this->primitive->value_as_DeDepthwiseConv2D();
@ -58,6 +59,10 @@ int DeconvDepthwiseConv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std
std::vector<int> out_shape{input->shape()}; std::vector<int> out_shape{input->shape()};
out_shape.at(1) = output_h; out_shape.at(1) = output_h;
out_shape.at(2) = output_w; out_shape.at(2) = output_w;
if (conv_prim->channelMultiplier() * input_channel != weight->shape()[0]) {
MS_LOG(ERROR) << "Conv depthwise only support group equals output channel.";
return RET_ERROR;
}
out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel
output->set_shape(out_shape); output->set_shape(out_shape);

View File

@ -53,7 +53,7 @@ void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_par
sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->block_channel_; // for batch loop sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->block_channel_; // for batch loop
sliding->in_h_step_ = conv_param->input_w_ * sliding->block_channel_; sliding->in_h_step_ = conv_param->input_w_ * sliding->block_channel_;
sliding->in_sh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->stride_h_; // stride H sliding->in_sh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->stride_h_; // stride H
sliding->in_sw_step_ = sliding->block_channel_ * conv_param->stride_h_; // stride W sliding->in_sw_step_ = sliding->block_channel_ * conv_param->stride_w_; // stride W
sliding->in_kh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->dilation_h_; // kernel H sliding->in_kh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->dilation_h_; // kernel H
sliding->in_kw_step_ = sliding->block_channel_ * conv_param->dilation_w_; // kernel W sliding->in_kw_step_ = sliding->block_channel_ * conv_param->dilation_w_; // kernel W
sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * block; sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * block;

View File

@ -20,7 +20,7 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
void CaffeConvolutionParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr) { void CaffeConvolutionParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr) {
if (attr == nullptr || attr->group == 1 || attr->group != attr->channelOut) { if (attr == nullptr || attr->group == 1) {
return; return;
} }
std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam(new schema::DepthwiseConv2DT()); std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam(new schema::DepthwiseConv2DT());

View File

@ -20,7 +20,7 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
void CaffeDeconvolutionParser::ParseGroupDeconvolution(schema::CNodeT *op, schema::DeConv2DT *attr) { void CaffeDeconvolutionParser::ParseGroupDeconvolution(schema::CNodeT *op, schema::DeConv2DT *attr) {
if (attr == nullptr || attr->group == 1 || attr->group != attr->channelIn) { if (attr == nullptr || attr->group == 1) {
return; return;
} }