!19933 [MS][LITE][Develop] optimize common convolution kernel

Merge pull request !19933 from sunsuodong/optimize_kernel_1.3
i-robot 2021-07-10 09:47:03 +00:00 committed by Gitee
commit a7af3bf197
2 changed files with 56 additions and 52 deletions
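
What changed: both common-convolution kernels previously let every thread stride across output tiles (for (thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num_)), re-deriving its scratch pointers and memset size on each iteration. The rewrite computes one contiguous block of output rows per thread before the batch loop, returns immediately from threads whose block falls past the end, and hoists the invariant pointer arithmetic (packed_input += task_id * deep * tile_n) out of the inner loop. A minimal sketch of the new partition arithmetic, with UP_DIV/MSMIN stubbed in as the usual ceiling-divide and min macros, and example sizes that are illustrative, not from the PR:

#include <stdio.h>

/* Stand-ins for nnacl's helper macros, so the sketch compiles on its own. */
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))

int main(void) {
  const int output_hw = 100; /* e.g. a 10x10 output plane */
  const int tile_n = 12;     /* fp16 tile: 16 on ARM64, 12 elsewhere */
  const int thread_num = 4;
  int block_per_thread = UP_DIV(UP_DIV(output_hw, tile_n), thread_num);
  for (int task_id = 0; task_id < thread_num; ++task_id) {
    int start_block = block_per_thread * task_id;
    int start_hw = start_block * tile_n;
    int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * tile_n);
    if (start_hw >= end_hw) {
      printf("task %d: no work, exits early\n", task_id);
      continue; /* the kernel does `return;` here */
    }
    printf("task %d: output rows [%d, %d)\n", task_id, start_hw, end_hw);
  }
  return 0;
}

With these numbers, UP_DIV(100, 12) = 9 tiles and block_per_thread = 3, so task 0 gets rows [0, 36), task 1 gets [36, 72), task 2 gets [72, 100), and task 3 exits without touching memory. The old interleaved schedule did the same total work but scattered each thread's writes across the output plane.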


@@ -27,31 +27,34 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
 #else
   const int tile_n = 12;
 #endif
-  int out_channel = conv_param->output_channel_;
-  int output_count = conv_param->output_h_ * conv_param->output_w_;
-  int output_tile_count = UP_DIV(output_count, tile_n);
+  int output_hw = conv_param->output_h_ * conv_param->output_w_;
+  int block_per_thread = UP_DIV(UP_DIV(output_hw, tile_n), conv_param->thread_num_);
+  int start_block = block_per_thread * task_id;
+  int start_hw = start_block * tile_n;
+  int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * tile_n);
+  if (start_hw >= end_hw) {
+    return;
+  }
+  int out_stride = conv_param->output_channel_ * tile_n;
   int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
+  packed_input += task_id * deep * tile_n;
+  col_major_input += task_id * deep * tile_n;
+  size_t input_size = deep * tile_n * sizeof(float16_t);
   for (int b = 0; b < conv_param->input_batch_; b++) {
-    int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
-    int out_batch_offset = b * out_channel * output_count;
-    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
-      int start_index = thread_id * tile_n;
-      int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
-      float16_t *gemm_input = packed_input + task_id * deep * tile_n;
-      float16_t *col_major_gemm_input = col_major_input + task_id * deep * tile_n;
-      size_t packed_input_size = deep * tile_n * sizeof(float16_t);
-      memset(gemm_input, 0, packed_input_size);
-      Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
-      int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
+    int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
+    int out_offset = b * conv_param->output_channel_ * output_hw + start_hw * conv_param->output_channel_;
+    for (int i = start_hw; i < end_hw; i += tile_n, out_offset += out_stride) {
+      int real_cal_row = MSMIN(output_hw - i, tile_n);
+      memset(packed_input, 0, input_size);
+      Im2ColPackUnitFp16(input_data + in_offset, conv_param, packed_input, real_cal_row, i);
 #ifdef ENABLE_ARM64
-      RowMajor2Col16MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep);
+      RowMajor2Col16MajorFp16Opt(packed_input, col_major_input, tile_n, deep);
 #else
-      RowMajor2Col12MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep);
+      RowMajor2Col12MajorFp16Opt(packed_input, col_major_input, tile_n, deep);
 #endif
-      MatMulFp16(col_major_gemm_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep,
-                 real_cal_num, out_channel, out_channel, OutType_Nhwc);
+      MatMulFp16(col_major_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep,
+                 real_cal_row, conv_param->output_channel_, conv_param->output_channel_, OutType_Nhwc);
     }
   }
 }
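
The per-thread scratch addressing itself is unchanged: the old code recomputed gemm_input = packed_input + task_id * deep * tile_n on every tile, while the new code bumps packed_input and col_major_input once at function entry. Either way, the caller must provide one tile_n-by-deep panel per thread in each buffer. A sketch of that contract, using fp32 types for portability and a hypothetical helper name (the real buffers are allocated by the calling convolution kernel, which this PR does not touch):

#include <stdlib.h>

/* Hypothetical helper: one tile_n x deep scratch panel per worker thread,
 * so concurrent tasks never write to overlapping memory. */
static float *AllocThreadPanels(int thread_num, int deep, int tile_n) {
  return (float *)malloc((size_t)thread_num * deep * tile_n * sizeof(float));
}

/* Usage sketch:
 *   float *packed_input    = AllocThreadPanels(thread_num, deep, tile_n);
 *   float *col_major_input = AllocThreadPanels(thread_num, deep, tile_n);
 *   // worker task_id then works inside packed_input + task_id * deep * tile_n
 */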


@@ -24,12 +24,10 @@
 // fp32 conv common
 void ConvFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data,
               float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param) {
-  int out_channel = conv_param->output_channel_;
-  int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
-  int output_count = conv_param->output_h_ * conv_param->output_w_;
   if (conv_param->thread_num_ == 0) {
     return;
   }
+  int output_hw = conv_param->output_h_ * conv_param->output_w_;
   Row2ColMajorFuncPtr Row2ColMajor = NULL;
 #ifdef ENABLE_AVX
   const int cal_num = C6NUM;
@@ -40,11 +38,11 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
 #elif defined(ENABLE_ARM64)
   int cal_num = 0;
   MatmulFloatOptFuncPtr MatmulFloatOpt = NULL;
-  if (output_count <= C4NUM) {
+  if (output_hw <= C4NUM) {
     cal_num = C4NUM;
     Row2ColMajor = RowMajor2Col4Major;
     MatmulFloatOpt = MatmulFloatNeon64OptRow4;
-  } else if (output_count <= C8NUM) {
+  } else if (output_hw <= C8NUM) {
     cal_num = C8NUM;
     Row2ColMajor = RowMajor2Col8Major;
     MatmulFloatOpt = MatmulFloatNeon64OptRow8;
@@ -60,43 +58,46 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
   const int cal_num = C12NUM;
   Row2ColMajor = RowMajor2Col12Major;
 #endif
-  int output_tile_count = UP_DIV(output_count, cal_num);
+  int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_);
+  int start_block = block_per_thread * task_id;
+  int start_hw = start_block * cal_num;
+  int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num);
+  if (start_hw >= end_hw) {
+    return;
+  }
+  int out_stride = conv_param->output_channel_ * cal_num;
+  int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
+  packed_input += task_id * deep * cal_num;
+  col_major_input += task_id * deep * cal_num;
+  size_t input_size = deep * cal_num * sizeof(float);
   for (int b = 0; b < conv_param->input_batch_; b++) {
-    int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
-    int out_batch_offset = b * out_channel * output_count;
-    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
-      int start_index = thread_id * cal_num;
-      int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num;
-      if (real_cal_num <= 0) {
-        return;
-      }
-      float *gemm_input = packed_input + task_id * deep * cal_num;
-      float *col_major_gemm_input = col_major_input + task_id * deep * cal_num;
-      size_t packed_input_size = deep * cal_num * sizeof(float);
-      memset(gemm_input, 0, packed_input_size);
-      Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
-      int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
+    int out_channel = conv_param->output_channel_;
+    int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
+    int out_offset = b * out_channel * output_hw + start_hw * out_channel;
+    for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) {
+      int real_cal_row = MSMIN(output_hw - i, cal_num);
+      memset(packed_input, 0, input_size);
+      Im2ColPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i);
+      Row2ColMajor(packed_input, col_major_input, cal_num, deep);
       float *gemm_output = output_data + out_offset;
-      Row2ColMajor(gemm_input, col_major_gemm_input, cal_num, deep);
 // x86 func param types are different
 #if ENABLE_AVX
-      MatmulFloatAvxOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_,
-                        deep, real_cal_num, out_channel, (size_t)out_channel, (size_t)OutType_Nhwc);
+      MatmulFloatAvxOpt(col_major_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep,
+                        real_cal_row, out_channel, (size_t)out_channel, (size_t)OutType_Nhwc);
 #elif ENABLE_SSE
-      MatmulFloatSse64Opt(col_major_gemm_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep,
-                          real_cal_num, out_channel, (size_t)out_channel, (int)OutType_Nhwc);
+      MatmulFloatSse64Opt(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep,
+                          real_cal_row, out_channel, (size_t)out_channel, (int)OutType_Nhwc);
 #elif ENABLE_ARM32
-      MatmulFloatNeon32Opt12x4(col_major_gemm_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_,
-                               deep, real_cal_num, out_channel, out_channel, OutType_Nhwc);
+      MatmulFloatNeon32Opt12x4(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep,
+                               real_cal_row, out_channel, out_channel, OutType_Nhwc);
 #elif ENABLE_ARM64
-      MatmulFloatOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep,
-                     real_cal_num, out_channel, out_channel, OutType_Nhwc);
+      MatmulFloatOpt(col_major_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_row,
+                     out_channel, out_channel, OutType_Nhwc);
 #else
-      MatMul12x8(col_major_gemm_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep,
-                 real_cal_num, out_channel, out_channel, OutType_Nhwc);
+      MatMul12x8(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, real_cal_row,
+                 out_channel, out_channel, OutType_Nhwc);
 #endif
     }
   }
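
The rewrite is only safe if the per-thread [start_hw, end_hw) ranges tile the output plane exactly: no gaps, no overlaps, no rows past output_hw. A standalone check of that invariant (not part of the PR; UP_DIV/MSMIN stubbed as before):

#include <assert.h>
#include <stdio.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))

int main(void) {
  for (int output_hw = 1; output_hw <= 256; ++output_hw) {
    for (int cal_num = 4; cal_num <= 16; cal_num += 4) { /* C4/C8/C12/C16 tiles */
      for (int thread_num = 1; thread_num <= 8; ++thread_num) {
        int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), thread_num);
        int covered = 0; /* rows assigned so far, in order */
        for (int task_id = 0; task_id < thread_num; ++task_id) {
          int start_hw = block_per_thread * task_id * cal_num;
          int end_hw = MSMIN(output_hw, (block_per_thread * task_id + block_per_thread) * cal_num);
          if (start_hw >= end_hw) {
            continue; /* this worker returns early in the kernel */
          }
          assert(start_hw == covered); /* contiguous: no gap, no overlap */
          covered = end_hw;
        }
        assert(covered == output_hw); /* every output row assigned exactly once */
      }
    }
  }
  printf("partition covers every output row exactly once\n");
  return 0;
}

Note also how out_offset starts at b * out_channel * output_hw + start_hw * out_channel and advances by out_stride = out_channel * cal_num per iteration: in the NHWC output each hw position holds out_channel values, so a thread's writes within its block form one contiguous run per batch.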