forked from mindspore-Ecosystem/mindspore
!19933 [MS][LITE][Develop] optimize common convolution kernel
Merge pull request !19933 from sunsuodong/optimize_kernel_1.3
This commit is contained in:
commit
a7af3bf197
|
@ -27,31 +27,34 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
|
|||
#else
|
||||
const int tile_n = 12;
|
||||
#endif
|
||||
int out_channel = conv_param->output_channel_;
|
||||
int output_count = conv_param->output_h_ * conv_param->output_w_;
|
||||
int output_tile_count = UP_DIV(output_count, tile_n);
|
||||
int output_hw = conv_param->output_h_ * conv_param->output_w_;
|
||||
int block_per_thread = UP_DIV(UP_DIV(output_hw, tile_n), conv_param->thread_num_);
|
||||
int start_block = block_per_thread * task_id;
|
||||
int start_hw = start_block * tile_n;
|
||||
int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * tile_n);
|
||||
if (start_hw >= end_hw) {
|
||||
return;
|
||||
}
|
||||
int out_stride = conv_param->output_channel_ * tile_n;
|
||||
int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
|
||||
packed_input += task_id * deep * tile_n;
|
||||
col_major_input += task_id * deep * tile_n;
|
||||
size_t input_size = deep * tile_n * sizeof(float16_t);
|
||||
|
||||
for (int b = 0; b < conv_param->input_batch_; b++) {
|
||||
int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
|
||||
int out_batch_offset = b * out_channel * output_count;
|
||||
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
|
||||
int start_index = thread_id * tile_n;
|
||||
int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
|
||||
float16_t *gemm_input = packed_input + task_id * deep * tile_n;
|
||||
float16_t *col_major_gemm_input = col_major_input + task_id * deep * tile_n;
|
||||
size_t packed_input_size = deep * tile_n * sizeof(float16_t);
|
||||
memset(gemm_input, 0, packed_input_size);
|
||||
Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
|
||||
|
||||
int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
|
||||
int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
|
||||
int out_offset = b * conv_param->output_channel_ * output_hw + start_hw * conv_param->output_channel_;
|
||||
for (int i = start_hw; i < end_hw; i += tile_n, out_offset += out_stride) {
|
||||
int real_cal_row = MSMIN(output_hw - i, tile_n);
|
||||
memset(packed_input, 0, input_size);
|
||||
Im2ColPackUnitFp16(input_data + in_offset, conv_param, packed_input, real_cal_row, i);
|
||||
#ifdef ENABLE_ARM64
|
||||
RowMajor2Col16MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep);
|
||||
RowMajor2Col16MajorFp16Opt(packed_input, col_major_input, tile_n, deep);
|
||||
#else
|
||||
RowMajor2Col12MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep);
|
||||
RowMajor2Col12MajorFp16Opt(packed_input, col_major_input, tile_n, deep);
|
||||
#endif
|
||||
MatMulFp16(col_major_gemm_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep,
|
||||
real_cal_num, out_channel, out_channel, OutType_Nhwc);
|
||||
MatMulFp16(col_major_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep,
|
||||
real_cal_row, conv_param->output_channel_, conv_param->output_channel_, OutType_Nhwc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,12 +24,10 @@
|
|||
// fp32 conv common
|
||||
void ConvFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data,
|
||||
float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param) {
|
||||
int out_channel = conv_param->output_channel_;
|
||||
int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
|
||||
int output_count = conv_param->output_h_ * conv_param->output_w_;
|
||||
if (conv_param->thread_num_ == 0) {
|
||||
return;
|
||||
}
|
||||
int output_hw = conv_param->output_h_ * conv_param->output_w_;
|
||||
Row2ColMajorFuncPtr Row2ColMajor = NULL;
|
||||
#ifdef ENABLE_AVX
|
||||
const int cal_num = C6NUM;
|
||||
|
@ -40,11 +38,11 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
|
|||
#elif defined(ENABLE_ARM64)
|
||||
int cal_num = 0;
|
||||
MatmulFloatOptFuncPtr MatmulFloatOpt = NULL;
|
||||
if (output_count <= C4NUM) {
|
||||
if (output_hw <= C4NUM) {
|
||||
cal_num = C4NUM;
|
||||
Row2ColMajor = RowMajor2Col4Major;
|
||||
MatmulFloatOpt = MatmulFloatNeon64OptRow4;
|
||||
} else if (output_count <= C8NUM) {
|
||||
} else if (output_hw <= C8NUM) {
|
||||
cal_num = C8NUM;
|
||||
Row2ColMajor = RowMajor2Col8Major;
|
||||
MatmulFloatOpt = MatmulFloatNeon64OptRow8;
|
||||
|
@ -60,43 +58,46 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
|
|||
const int cal_num = C12NUM;
|
||||
Row2ColMajor = RowMajor2Col12Major;
|
||||
#endif
|
||||
int output_tile_count = UP_DIV(output_count, cal_num);
|
||||
|
||||
int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_);
|
||||
int start_block = block_per_thread * task_id;
|
||||
int start_hw = start_block * cal_num;
|
||||
int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num);
|
||||
if (start_hw >= end_hw) {
|
||||
return;
|
||||
}
|
||||
int out_stride = conv_param->output_channel_ * cal_num;
|
||||
int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
|
||||
packed_input += task_id * deep * cal_num;
|
||||
col_major_input += task_id * deep * cal_num;
|
||||
size_t input_size = deep * cal_num * sizeof(float);
|
||||
|
||||
for (int b = 0; b < conv_param->input_batch_; b++) {
|
||||
int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
|
||||
int out_batch_offset = b * out_channel * output_count;
|
||||
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
|
||||
int start_index = thread_id * cal_num;
|
||||
int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num;
|
||||
if (real_cal_num <= 0) {
|
||||
return;
|
||||
}
|
||||
float *gemm_input = packed_input + task_id * deep * cal_num;
|
||||
float *col_major_gemm_input = col_major_input + task_id * deep * cal_num;
|
||||
size_t packed_input_size = deep * cal_num * sizeof(float);
|
||||
memset(gemm_input, 0, packed_input_size);
|
||||
Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
|
||||
|
||||
int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
|
||||
int out_channel = conv_param->output_channel_;
|
||||
int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
|
||||
int out_offset = b * out_channel * output_hw + start_hw * out_channel;
|
||||
for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) {
|
||||
int real_cal_row = MSMIN(output_hw - i, cal_num);
|
||||
memset(packed_input, 0, input_size);
|
||||
Im2ColPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i);
|
||||
Row2ColMajor(packed_input, col_major_input, cal_num, deep);
|
||||
float *gemm_output = output_data + out_offset;
|
||||
|
||||
Row2ColMajor(gemm_input, col_major_gemm_input, cal_num, deep);
|
||||
// x86 func param types are different
|
||||
#if ENABLE_AVX
|
||||
MatmulFloatAvxOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_,
|
||||
deep, real_cal_num, out_channel, (size_t)out_channel, (size_t)OutType_Nhwc);
|
||||
MatmulFloatAvxOpt(col_major_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep,
|
||||
real_cal_row, out_channel, (size_t)out_channel, (size_t)OutType_Nhwc);
|
||||
#elif ENABLE_SSE
|
||||
MatmulFloatSse64Opt(col_major_gemm_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep,
|
||||
real_cal_num, out_channel, (size_t)out_channel, (int)OutType_Nhwc);
|
||||
MatmulFloatSse64Opt(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep,
|
||||
real_cal_row, out_channel, (size_t)out_channel, (int)OutType_Nhwc);
|
||||
#elif ENABLE_ARM32
|
||||
MatmulFloatNeon32Opt12x4(col_major_gemm_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_,
|
||||
deep, real_cal_num, out_channel, out_channel, OutType_Nhwc);
|
||||
MatmulFloatNeon32Opt12x4(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep,
|
||||
real_cal_row, out_channel, out_channel, OutType_Nhwc);
|
||||
#elif ENABLE_ARM64
|
||||
MatmulFloatOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep,
|
||||
real_cal_num, out_channel, out_channel, OutType_Nhwc);
|
||||
MatmulFloatOpt(col_major_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_row,
|
||||
out_channel, out_channel, OutType_Nhwc);
|
||||
#else
|
||||
MatMul12x8(col_major_gemm_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep,
|
||||
real_cal_num, out_channel, out_channel, OutType_Nhwc);
|
||||
MatMul12x8(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, real_cal_row,
|
||||
out_channel, out_channel, OutType_Nhwc);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue