fix conv2d_transpose bug

This commit is contained in:
greatpanc 2021-10-28 19:52:00 +08:00
parent 324c4f0ca2
commit 276aeb4566
1 changed files with 2 additions and 7 deletions

View File

@ -336,14 +336,9 @@ kernel::InnerKernel *OpenCLConv2dTransposeCreator(const std::vector<lite::Tensor
MS_CHECK_TRUE_RET(inputs.front() != nullptr, nullptr);
MS_CHECK_TRUE_RET(outputs.front() != nullptr, nullptr);
MS_ASSERT(!inputs.empty());
MS_ASSERT(!outputs.empty());
MS_ASSERT(inputs.front()->shape().size() == DIMENSION_4D);
MS_ASSERT(outputs.front()->shape().size() == DIMENSION_4D);
auto *conv_param = reinterpret_cast<ConvParameter *>(opParameter);
int input_channel = inputs.front()->shape().at(3);
int output_channel = outputs.front()->shape().at(3);
int input_channel = conv_param->input_channel_;
int output_channel = conv_param->output_channel_;
int group = conv_param->group_;
// case 1: depthwise Conv2dTranspose