!4940 modify conv to make fp16 run pass

Merge pull request !4940 from zhaozhenlong/lite/issue/modify_conv_for_fp16
This commit is contained in:
mindspore-ci-bot 2020-08-21 21:16:31 +08:00 committed by Gitee
commit 689efef7f0
1 changed files with 7 additions and 8 deletions

View File

@ -235,13 +235,15 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten
conv_param->input_w_ = inputs.front()->Width();
conv_param->output_h_ = outputs.front()->Height();
conv_param->output_w_ = outputs.front()->Width();
bool prefer_flag = false;
if (conv_param->output_h_ * conv_param->output_w_ > 64) {
prefer_flag = true;
}
// bool prefer_flag = false;
// if (conv_param->output_h_ * conv_param->output_w_ > 64) {
// prefer_flag = true;
// }
kernel::LiteKernel *kernel = nullptr;
if (kernel_h == 1 && kernel_w == 1) {
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else if (kernel_h == 1 && kernel_w == 1) {
// kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else {
@ -253,9 +255,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten
if (use_winograd) {
kernel = new (std::nothrow)
kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit);
} else if (prefer_flag && kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 &&
dilation_w == 1) {
kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
}
if (kernel_h != 1 && kernel_w != 1 && !use_winograd) {
kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);