diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c
index e42041f3f29..01fbe9e2fc5 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c
@@ -27,31 +27,34 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
 #else
   const int tile_n = 12;
 #endif
-  int out_channel = conv_param->output_channel_;
-  int output_count = conv_param->output_h_ * conv_param->output_w_;
-  int output_tile_count = UP_DIV(output_count, tile_n);
+  int output_hw = conv_param->output_h_ * conv_param->output_w_;
+  int block_per_thread = UP_DIV(UP_DIV(output_hw, tile_n), conv_param->thread_num_);
+  int start_block = block_per_thread * task_id;
+  int start_hw = start_block * tile_n;
+  int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * tile_n);
+  if (start_hw >= end_hw) {
+    return;
+  }
+  int out_stride = conv_param->output_channel_ * tile_n;
   int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
+  packed_input += task_id * deep * tile_n;
+  col_major_input += task_id * deep * tile_n;
+  size_t input_size = deep * tile_n * sizeof(float16_t);
 
   for (int b = 0; b < conv_param->input_batch_; b++) {
-    int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
-    int out_batch_offset = b * out_channel * output_count;
-    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
-      int start_index = thread_id * tile_n;
-      int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
-      float16_t *gemm_input = packed_input + task_id * deep * tile_n;
-      float16_t *col_major_gemm_input = col_major_input + task_id * deep * tile_n;
-      size_t packed_input_size = deep * tile_n * sizeof(float16_t);
-      memset(gemm_input, 0, packed_input_size);
-      Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
-
-      int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
+    int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
+    int out_offset = b * conv_param->output_channel_ * output_hw + start_hw * conv_param->output_channel_;
+    for (int i = start_hw; i < end_hw; i += tile_n, out_offset += out_stride) {
+      int real_cal_row = MSMIN(output_hw - i, tile_n);
+      memset(packed_input, 0, input_size);
+      Im2ColPackUnitFp16(input_data + in_offset, conv_param, packed_input, real_cal_row, i);
 #ifdef ENABLE_ARM64
-      RowMajor2Col16MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep);
+      RowMajor2Col16MajorFp16Opt(packed_input, col_major_input, tile_n, deep);
 #else
-      RowMajor2Col12MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep);
+      RowMajor2Col12MajorFp16Opt(packed_input, col_major_input, tile_n, deep);
 #endif
-      MatMulFp16(col_major_gemm_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep,
-                 real_cal_num, out_channel, out_channel, OutType_Nhwc);
+      MatMulFp16(col_major_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep,
+                 real_cal_row, conv_param->output_channel_, conv_param->output_channel_, OutType_Nhwc);
     }
   }
 }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c
index ae159e061a1..960946d4336 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c
@@ -24,12 +24,10 @@
 // fp32 conv common
 void ConvFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data,
               float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param) {
-  int out_channel = conv_param->output_channel_;
-  int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
-  int output_count = conv_param->output_h_ * conv_param->output_w_;
   if (conv_param->thread_num_ == 0) {
     return;
   }
+  int output_hw = conv_param->output_h_ * conv_param->output_w_;
   Row2ColMajorFuncPtr Row2ColMajor = NULL;
 #ifdef ENABLE_AVX
   const int cal_num = C6NUM;
@@ -40,11 +38,11 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
 #elif defined(ENABLE_ARM64)
   int cal_num = 0;
   MatmulFloatOptFuncPtr MatmulFloatOpt = NULL;
-  if (output_count <= C4NUM) {
+  if (output_hw <= C4NUM) {
     cal_num = C4NUM;
     Row2ColMajor = RowMajor2Col4Major;
     MatmulFloatOpt = MatmulFloatNeon64OptRow4;
-  } else if (output_count <= C8NUM) {
+  } else if (output_hw <= C8NUM) {
     cal_num = C8NUM;
     Row2ColMajor = RowMajor2Col8Major;
     MatmulFloatOpt = MatmulFloatNeon64OptRow8;
@@ -60,43 +58,46 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
   const int cal_num = C12NUM;
   Row2ColMajor = RowMajor2Col12Major;
 #endif
-  int output_tile_count = UP_DIV(output_count, cal_num);
+
+  int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_);
+  int start_block = block_per_thread * task_id;
+  int start_hw = start_block * cal_num;
+  int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num);
+  if (start_hw >= end_hw) {
+    return;
+  }
+  int out_stride = conv_param->output_channel_ * cal_num;
+  int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
+  packed_input += task_id * deep * cal_num;
+  col_major_input += task_id * deep * cal_num;
+  size_t input_size = deep * cal_num * sizeof(float);
 
   for (int b = 0; b < conv_param->input_batch_; b++) {
-    int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
-    int out_batch_offset = b * out_channel * output_count;
-    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
-      int start_index = thread_id * cal_num;
-      int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num;
-      if (real_cal_num <= 0) {
-        return;
-      }
-      float *gemm_input = packed_input + task_id * deep * cal_num;
-      float *col_major_gemm_input = col_major_input + task_id * deep * cal_num;
-      size_t packed_input_size = deep * cal_num * sizeof(float);
-      memset(gemm_input, 0, packed_input_size);
-      Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
-
-      int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
+    int out_channel = conv_param->output_channel_;
+    int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
+    int out_offset = b * out_channel * output_hw + start_hw * out_channel;
+    for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) {
+      int real_cal_row = MSMIN(output_hw - i, cal_num);
+      memset(packed_input, 0, input_size);
+      Im2ColPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i);
+      Row2ColMajor(packed_input, col_major_input, cal_num, deep);
       float *gemm_output = output_data + out_offset;
-
-      Row2ColMajor(gemm_input, col_major_gemm_input, cal_num, deep);
       // x86 func param types are different
 #if ENABLE_AVX
-      MatmulFloatAvxOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_,
-                        deep, real_cal_num, out_channel, (size_t)out_channel, (size_t)OutType_Nhwc);
+      MatmulFloatAvxOpt(col_major_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep,
+                        real_cal_row, out_channel, (size_t)out_channel, (size_t)OutType_Nhwc);
 #elif ENABLE_SSE
-      MatmulFloatSse64Opt(col_major_gemm_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep,
-                          real_cal_num, out_channel, (size_t)out_channel, (int)OutType_Nhwc);
+      MatmulFloatSse64Opt(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep,
+                          real_cal_row, out_channel, (size_t)out_channel, (int)OutType_Nhwc);
 #elif ENABLE_ARM32
-      MatmulFloatNeon32Opt12x4(col_major_gemm_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_,
-                               deep, real_cal_num, out_channel, out_channel, OutType_Nhwc);
+      MatmulFloatNeon32Opt12x4(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep,
+                               real_cal_row, out_channel, out_channel, OutType_Nhwc);
 #elif ENABLE_ARM64
-      MatmulFloatOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep,
-                     real_cal_num, out_channel, out_channel, OutType_Nhwc);
+      MatmulFloatOpt(col_major_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_row,
+                     out_channel, out_channel, OutType_Nhwc);
 #else
-      MatMul12x8(col_major_gemm_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep,
-                 real_cal_num, out_channel, out_channel, OutType_Nhwc);
+      MatMul12x8(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, real_cal_row,
+                 out_channel, out_channel, OutType_Nhwc);
 #endif
     }
   }
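Note (not part of the patch): both files replace the old strided dispatch, where each thread_id stepped through the output tiles in increments of thread_num_, with one contiguous block of output rows per task, and each task now packs into its own fixed slice of packed_input/col_major_input. The standalone C sketch below only illustrates that partition arithmetic, reusing the names from the patch (output_hw, tile_n, block_per_thread, start_hw, end_hw); UP_DIV and MSMIN are defined locally as ceiling division and minimum so the file compiles on its own, and the shapes in main() are made up for the demonstration.

/*
 * Standalone sketch of the per-thread block partition used by ConvFp16/ConvFp32
 * after this change. Not taken from the repository; macro definitions and the
 * example shapes are local assumptions for illustration only.
 */
#include <stdio.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))

/* Print the [start_hw, end_hw) range of output pixels handled by task_id. */
static void print_range(int output_hw, int tile_n, int thread_num, int task_id) {
  /* Tiles of tile_n output rows are split as evenly as ceiling division allows;
   * the last busy task is clamped to output_hw, and any task past the end of
   * the work gets an empty range (the kernels return early in that case). */
  int block_per_thread = UP_DIV(UP_DIV(output_hw, tile_n), thread_num);
  int start_block = block_per_thread * task_id;
  int start_hw = start_block * tile_n;
  int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * tile_n);
  if (start_hw >= end_hw) {
    printf("task %d: no work\n", task_id);
    return;
  }
  printf("task %d: rows [%d, %d), %d tile(s) of up to %d rows\n", task_id, start_hw, end_hw,
         UP_DIV(end_hw - start_hw, tile_n), tile_n);
}

int main(void) {
  /* Example: 27x27 output (729 pixels), ARM64 fp16 tile of 16 rows, 4 threads. */
  const int output_hw = 27 * 27;
  const int tile_n = 16;
  const int thread_num = 4;
  for (int task_id = 0; task_id < thread_num; ++task_id) {
    print_range(output_hw, tile_n, thread_num, task_id);
  }
  return 0;
}

With these example numbers, tasks 0-2 get rows [0, 192), [192, 384) and [384, 576), while task 3 gets the clamped tail [576, 729); a task whose start_hw reaches output_hw prints "no work", mirroring the early "if (start_hw >= end_hw) return" guard added in both kernels.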