forked from mindspore-Ecosystem/mindspore
!4940 Modify conv so that the fp16 run passes
Merge pull request !4940 from zhaozhenlong/lite/issue/modify_conv_for_fp16
This commit is contained in:
commit
689efef7f0
|
@ -235,13 +235,15 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten
|
|||
conv_param->input_w_ = inputs.front()->Width();
|
||||
conv_param->output_h_ = outputs.front()->Height();
|
||||
conv_param->output_w_ = outputs.front()->Width();
|
||||
bool prefer_flag = false;
|
||||
if (conv_param->output_h_ * conv_param->output_w_ > 64) {
|
||||
prefer_flag = true;
|
||||
}
|
||||
// bool prefer_flag = false;
|
||||
// if (conv_param->output_h_ * conv_param->output_w_ > 64) {
|
||||
// prefer_flag = true;
|
||||
// }
|
||||
|
||||
kernel::LiteKernel *kernel = nullptr;
|
||||
if (kernel_h == 1 && kernel_w == 1) {
|
||||
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
|
||||
kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
} else if (kernel_h == 1 && kernel_w == 1) {
|
||||
// kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
} else {
|
||||
|
@ -253,9 +255,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten
|
|||
if (use_winograd) {
|
||||
kernel = new (std::nothrow)
|
||||
kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit);
|
||||
} else if (prefer_flag && kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 &&
|
||||
dilation_w == 1) {
|
||||
kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
}
|
||||
if (kernel_h != 1 && kernel_w != 1 && !use_winograd) {
|
||||
kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
|
|
Loading…
Reference in New Issue