From 4b0edb34ce74465fab4c6794a39a3b63b44d0261 Mon Sep 17 00:00:00 2001 From: yangruoqi713 Date: Mon, 17 Jan 2022 14:18:46 +0800 Subject: [PATCH] [MSLITE][DEVELOP] optimize conv winograd --- .jenkins/check/config/whitelizard.txt | 1 + .../cpu/nnacl/fp16/conv_fp16.c | 81 ++- .../cpu/nnacl/fp16/conv_fp16.h | 2 +- .../cpu/nnacl/fp16/winograd_transform_fp16.c | 182 ++++-- .../cpu/nnacl/fp16/winograd_transform_fp16.h | 4 + .../cpu/nnacl/fp16/winograd_utils_fp16.c | 335 +++++++++++ .../cpu/nnacl/fp16/winograd_utils_fp16.h | 76 +++ .../cpu/nnacl/fp32/conv_winograd_fp32.c | 96 ++-- .../cpu/nnacl/fp32/conv_winograd_fp32.h | 2 +- .../cpu/nnacl/fp32/winograd_transform.c | 152 +++-- .../cpu/nnacl/fp32/winograd_transform.h | 4 + .../cpu/nnacl/fp32/winograd_utils.c | 544 ++++++++++++++++++ .../cpu/nnacl/fp32/winograd_utils.h | 64 +++ .../fp32/convolution_winograd_fp32_coder.cc | 11 +- .../fp32/convolution_winograd_fp32_coder.h | 4 +- .../nnacl_serializer/nnacl_fp32_serializer.cc | 4 + .../nnacl_serializer/nnacl_fp32_serializer.h | 2 + .../wrapper/fp32/conv_winograd_fp32_wrapper.h | 30 + .../arm/fp16/convolution_winograd_fp16.cc | 29 +- .../arm/fp16/convolution_winograd_fp16.h | 10 +- .../arm/fp32/convolution_winograd_fp32.cc | 29 +- .../arm/fp32/convolution_winograd_fp32.h | 10 +- .../lite/test/config/models_caffe_fp16.cfg | 36 +- .../lite/test/config/models_onnx_fp16.cfg | 18 +- mindspore/lite/test/config/models_tf_fp16.cfg | 12 +- .../lite/test/config/models_tflite_fp16.cfg | 14 +- 26 files changed, 1521 insertions(+), 231 deletions(-) create mode 100644 mindspore/lite/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.h diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index 05b77108e23..430e671298e 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -195,3 +195,4 @@ mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::PartitionNode mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c:InstanceNormNC8HW8Fp16 mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::init_global_variable +mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.c:ConvWinogardFp32 diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c index 31962281a81..6a5d6a59f26 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c @@ -187,7 +187,7 @@ void Conv1x1OutNc8hw8MultiThreadByWeightFp16(const float16_t *input, float16_t * // fp16 convolution winograd void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight, const float16_t *bias_data, float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, - const ConvParameter *conv_param, InputTransFp16Func in_func, OutputTransFp16Func out_func) { + const ConvParameter *conv_param, TransFp16FuncList trans_func) { #ifdef ENABLE_ARM64 const int tile_num = 16; #else @@ -196,6 +196,7 @@ void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight NNACL_CHECK_ZERO_RETURN(conv_param->output_unit_); NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_); int in_channel = conv_param->input_channel_; + int input_unit = conv_param->input_unit_; int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); int output_count = out_w_block * out_h_block; @@ -204,16 +205,12 @@ void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight NNACL_CHECK_ZERO_RETURN(real_tile); int output_tile_count = UP_DIV(output_count, real_tile); int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); - int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_; + int input_unit_square = input_unit * input_unit; - float16_t *trans_input = buffer_list[0]; - float16_t *gemm_out = buffer_list[1]; - float16_t *tmp_data = buffer_list[2]; - float16_t *col_buffer = buffer_list[3]; - int trans_input_offset = tile_num * input_unit_square * in_channel; - int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM; - int tmp_data_offset = input_unit_square * C8NUM; - int col_buffer_offset = tile_num * in_channel; + float16_t *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel; + float16_t *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM; + float16_t *tmp_data = buffer_list[2] + task_id * input_unit_square * C8NUM; + float16_t *col_buffer = buffer_list[3] + task_id * tile_num * in_channel; // step 1 : filter transform (pre-processed offline) // step 2 : input transform (online) for (int b = 0; b < conv_param->input_batch_; b++) { @@ -226,30 +223,64 @@ void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight if (cal_num <= 0) { return; } - WinogradInputTransformFp16(input_data + in_batch_offset, trans_input + task_id * trans_input_offset, - tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, - in_func); - // step 3 : gemm - float16_t *src_ptr = trans_input + task_id * trans_input_offset; - float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset; - float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; - for (int i = 0; i < input_unit_square; ++i) { + #ifdef ENABLE_ARM64 - RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); + // Optimize input transform. Only valid for arm64, the tile num is 16. + // For arm32, the tile_num is 12. The function(InputTransform4x4Pack12Fp16) needs to be rewritten. + bool fused_pack = + (cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL); + if (fused_pack) { + float16_t *opt_trans_input = + buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, C8NUM); + WinogradInputTransformOptStepFp16(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num, + out_tile_index, out_w_block, conv_param, trans_func.in_step_func_); + + for (int w_index = 0; w_index < input_unit; w_index++) { + float16_t *src_w = opt_trans_input + w_index * input_unit * tile_num * C8NUM; + for (int c = 0; c < UP_DIV(in_channel, C8NUM); c++) { + int real_c = in_channel - c * C8NUM; + real_c = real_c > C8NUM ? C8NUM : real_c; + float16_t *src_c = src_w + c * input_unit_square * tile_num * C8NUM; + float16_t *dst_c = trans_input + c * tile_num * C8NUM; + trans_func.in_pack_func_(src_c, dst_c, C8NUM, in_channel * tile_num, real_c); + } + + for (int h_index = 0; h_index < input_unit; h_index++) { + const float16_t *gemm_input = trans_input + h_index * tile_num * in_channel; + int point_index = h_index * input_unit + w_index; + const float16_t *gemm_weight = trans_weight + point_index * in_channel * oc8 * C8NUM; + MatMulFp16(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num, + oc8 * C8NUM, input_unit_square, OutType_TileC8); + } + } + } else { +#endif + WinogradInputTransformFp16(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_func_); + // step 3 : gemm + float16_t *src_ptr = trans_input; + float16_t *dst_ptr = gemm_out; + float16_t *tmp_col_ptr = col_buffer; + for (int i = 0; i < input_unit_square; ++i) { +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); #else RowMajor2Col12MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); #endif - MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel, - cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8); + MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel, + cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8); + } +#ifdef ENABLE_ARM64 } +#endif // step 4 : output transform if (conv_param->out_format_ != NNACL_NC4HW4) { // nc4hw4 - WinogradOutputNHWCTransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset, bias_data, - cal_num, out_tile_index, out_w_block, conv_param, out_func); + WinogradOutputNHWCTransformFp16(gemm_out, output_data + out_batch_offset, bias_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.out_func_); } else { - WinogradOutputNC8HW8TransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset, - bias_data, cal_num, out_tile_index, out_w_block, conv_param, out_func); + WinogradOutputNC8HW8TransformFp16(gemm_out, output_data + out_batch_offset, bias_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.out_func_); } } } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.h index 8315454e81e..32d52d7ac9b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.h @@ -49,7 +49,7 @@ void Conv1x1OutNc8hw8MultiThreadByWeightFp16(const float16_t *input, float16_t * // fp16 convolution winograd void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight, const float16_t *bias_data, float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, - const ConvParameter *conv_param, InputTransFp16Func in_func, OutputTransFp16Func out_func); + const ConvParameter *conv_param, TransFp16FuncList trans_func); #ifdef __cplusplus } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.c index 61408fa645b..18577d401f6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.c @@ -16,6 +16,77 @@ #include "nnacl/fp16/winograd_transform_fp16.h" +void PrepareTransInputFp16(const float16_t *src_data, float16_t *dst_data, int interval_x_s, int interval_x_e, + int interval_y_s, int interval_y_e, int real_c, const ConvParameter *conv_param) { + int input_unit = conv_param->input_unit_; + int in_channel = conv_param->input_channel_; + int input_w = conv_param->input_w_; + + // clear tmp buffer + if (interval_x_e - interval_x_s != input_unit || interval_y_e - interval_y_s != input_unit) { + memset(dst_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t)); + } + + // get real input block with padding + if (real_c == C8NUM) { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = src_data + src_x_offset; + float16_t *dst_addr = dst_data + dst_x_offset; +#ifdef ENABLE_NEON + vst1q_f16(dst_addr, vld1q_f16(src_addr)); +#else + for (int k = 0; k < C8NUM; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + } + } + } else if (real_c < 8 && real_c >= 4) { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = src_data + src_x_offset; + float16_t *dst_addr = dst_data + dst_x_offset; + int rc = real_c - 4; +#ifdef ENABLE_NEON + vst1_f16(dst_addr, vld1_f16(src_addr)); +#else + for (int k = 0; k < C4NUM; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + src_addr += 4; + dst_addr += 4; + for (int i = 0; i < rc; ++i) { + dst_addr[i] = src_addr[i]; + } + } + } + } else { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = src_data + src_x_offset; + float16_t *dst_addr = dst_data + dst_x_offset; + for (int k = 0; k < real_c; k++) { + dst_addr[k] = src_addr[k]; + } + } + } + } +} + // fp16 common winograd void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, @@ -49,71 +120,11 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); int dst_plane_offset = c * in_channel; for (int ic = 0; ic < ic8; ic++) { - // clear tmp buffer - memset(tmp_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t)); - int real_c = in_channel - ic * C8NUM; real_c = real_c > C8NUM ? C8NUM : real_c; - int src_ic8_offset = src_plane_offset + ic * C8NUM; - - // get real input block with padding - if (real_c == C8NUM) { - for (int interval = interval_y_s; interval < interval_y_e; interval++) { - int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel; - int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; - for (int j = 0; j < (interval_x_e - interval_x_s); j++) { - int src_x_offset = src_y_offset + j * in_channel; - int dst_x_offset = dst_y_offset + j * C8NUM; - const float16_t *src_addr = input_data + src_x_offset; - float16_t *dst_addr = tmp_data + dst_x_offset; -#ifdef ENABLE_NEON - vst1q_f16(dst_addr, vld1q_f16(src_addr)); -#else - for (int k = 0; k < C8NUM; k++) { - dst_addr[k] = src_addr[k]; - } -#endif - } - } - } else if (real_c < 8 && real_c >= 4) { - for (int interval = interval_y_s; interval < interval_y_e; interval++) { - int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel; - int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; - for (int j = 0; j < (interval_x_e - interval_x_s); j++) { - int src_x_offset = src_y_offset + j * in_channel; - int dst_x_offset = dst_y_offset + j * C8NUM; - const float16_t *src_addr = input_data + src_x_offset; - float16_t *dst_addr = tmp_data + dst_x_offset; - int rc = real_c - 4; -#ifdef ENABLE_NEON - vst1_f16(dst_addr, vld1_f16(src_addr)); -#else - for (int k = 0; k < C4NUM; k++) { - dst_addr[k] = src_addr[k]; - } -#endif - src_addr += 4; - dst_addr += 4; - for (int i = 0; i < rc; ++i) { - dst_addr[i] = src_addr[i]; - } - } - } - } else { - for (int interval = interval_y_s; interval < interval_y_e; interval++) { - int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel; - int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; - for (int j = 0; j < (interval_x_e - interval_x_s); j++) { - int src_x_offset = src_y_offset + j * in_channel; - int dst_x_offset = dst_y_offset + j * C8NUM; - const float16_t *src_addr = input_data + src_x_offset; - float16_t *dst_addr = tmp_data + dst_x_offset; - for (int k = 0; k < real_c; k++) { - dst_addr[k] = src_addr[k]; - } - } - } - } + const float16_t *src_data = input_data + src_plane_offset + ic * C8NUM; + PrepareTransInputFp16(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, + conv_param); // input transform int dst_ic8_offset = dst_plane_offset + ic * C8NUM; @@ -125,6 +136,51 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in } // cal_tile_num loop } +// Only support arm64 +void WinogradInputTransformOptStepFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, + int cal_num, int out_tile_index, int out_w_block_num, + const ConvParameter *conv_param, InputTransStepFp16Func func) { + const int tile_num = 16; + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; + int ic8 = UP_DIV(in_channel, C8NUM); + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + if (out_w_block_num == 0) { + return; + } + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); + + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * C8NUM; + for (int ic = 0; ic < ic8; ic++) { + int real_c = in_channel - ic * C8NUM; + real_c = real_c > C8NUM ? C8NUM : real_c; + const float16_t *src_data = input_data + src_plane_offset + ic * C8NUM; + PrepareTransInputFp16(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, + conv_param); + + // input transform + int dst_ic8_offset = dst_plane_offset + ic * tile_num * input_unit * input_unit * C8NUM; + size_t dst_step = input_unit * tile_num * C8NUM; + float16_t *trans_input_ptr = trans_input + dst_ic8_offset; + func(tmp_data, trans_input_ptr, C8NUM, dst_step, tile_num * C8NUM); + } + out_tile_index++; + } // cal_tile_num loop +} + void WinogradOutputNHWCTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, int cal_num, int out_tile_index, int output_unit_num, const ConvParameter *conv_param, OutputTransFp16Func func) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.h index 286b75b61c5..f7b06f1b9c4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.h @@ -33,6 +33,10 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, InputTransFp16Func func); +void WinogradInputTransformOptStepFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, + int cal_num, int out_tile_index, int out_w_block_num, + const ConvParameter *conv_param, InputTransStepFp16Func func); + void WinogradOutputNHWCTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, int cal_num, int out_tile_index, int output_unit_num, const ConvParameter *conv_param, OutputTransFp16Func func); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_utils_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_utils_fp16.c index 0c46bd323d2..af99832a92e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_utils_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_utils_fp16.c @@ -20,6 +20,38 @@ #define MIN_UNIT_FP16 2 #define MAX_UNIT_FP16 4 +#ifdef ENABLE_ARM64 +void transpose8(float16x8_t *s0, float16x8_t *s1, float16x8_t *s2, float16x8_t *s3, float16x8_t *s4, float16x8_t *s5, + float16x8_t *s6, float16x8_t *s7) { + float32x4_t m0 = (float32x4_t)(vtrn1q_f16(*s0, *s1)); + float32x4_t m1 = (float32x4_t)(vtrn2q_f16(*s0, *s1)); + float32x4_t m2 = (float32x4_t)(vtrn1q_f16(*s2, *s3)); + float32x4_t m3 = (float32x4_t)(vtrn2q_f16(*s2, *s3)); + float32x4_t m4 = (float32x4_t)(vtrn1q_f16(*s4, *s5)); + float32x4_t m5 = (float32x4_t)(vtrn2q_f16(*s4, *s5)); + float32x4_t m6 = (float32x4_t)(vtrn1q_f16(*s6, *s7)); + float32x4_t m7 = (float32x4_t)(vtrn2q_f16(*s6, *s7)); + + float64x2_t t0 = (float64x2_t)(vtrn1q_f32(m0, m2)); + float64x2_t t2 = (float64x2_t)(vtrn2q_f32(m0, m2)); + float64x2_t t1 = (float64x2_t)(vtrn1q_f32(m1, m3)); + float64x2_t t3 = (float64x2_t)(vtrn2q_f32(m1, m3)); + float64x2_t t4 = (float64x2_t)(vtrn1q_f32(m4, m6)); + float64x2_t t6 = (float64x2_t)(vtrn2q_f32(m4, m6)); + float64x2_t t5 = (float64x2_t)(vtrn1q_f32(m5, m7)); + float64x2_t t7 = (float64x2_t)(vtrn2q_f32(m5, m7)); + + *s0 = (float16x8_t)(vtrn1q_f64(t0, t4)); + *s4 = (float16x8_t)(vtrn2q_f64(t0, t4)); + *s1 = (float16x8_t)(vtrn1q_f64(t1, t5)); + *s5 = (float16x8_t)(vtrn2q_f64(t1, t5)); + *s2 = (float16x8_t)(vtrn1q_f64(t2, t6)); + *s6 = (float16x8_t)(vtrn2q_f64(t2, t6)); + *s3 = (float16x8_t)(vtrn1q_f64(t3, t7)); + *s7 = (float16x8_t)(vtrn2q_f64(t3, t7)); +} +#endif + static InputTransFp16Func InputTransFp16FuncList[] = { NULL, NULL, NULL, NULL, InputTransform4x4UnitFp16, NULL, InputTransform6x6UnitFp16, NULL, InputTransform8x8UnitFp16}; @@ -81,6 +113,25 @@ static OutputTransFp16Func OutputTransFp16FuncRelu6List8[] = {NULL, InputTransFp16Func GetInputTransFp16Func(int input_unit) { return InputTransFp16FuncList[input_unit]; } +#ifdef ENABLE_ARM64 +static InputTransStepFp16Func InputTransStepFp16FuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4StepFp16, NULL, InputTransform6x6StepFp16, NULL, InputTransform8x8StepFp16}; + +static InputTransPackFp16Func InputTransPackFp16FuncList[] = {NULL, + NULL, + NULL, + NULL, + InputTransform4x4Pack16Fp16, + NULL, + InputTransform6x6Pack16Fp16, + NULL, + InputTransform8x8Pack16Fp16}; + +InputTransStepFp16Func GetInputTransStepFp16Func(int input_unit) { return InputTransStepFp16FuncList[input_unit]; } + +InputTransPackFp16Func GetInputTransPackFp16Func(int input_unit) { return InputTransPackFp16FuncList[input_unit]; } +#endif + void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { int j = 0; if (real_c == 8) { @@ -160,6 +211,74 @@ void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, i } } +void InputTransform4x4StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step) { + for (int l = 0; l < 4; ++l) { + const float16_t *src_ptr = src_data + l * 4 * src_step; + float16_t *dst_ptr = dst_data + l * dst_row_step; + + float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step); + float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step); + float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step); + float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step); + float16x8_t m0 = vsubq_f16(s0, s2); + float16x8_t m1 = vaddq_f16(s1, s2); + float16x8_t m2 = vsubq_f16(s2, s1); + float16x8_t m3 = vsubq_f16(s3, s1); + + vst1q_f16(dst_ptr + 0 * dst_step, m0); + vst1q_f16(dst_ptr + 1 * dst_step, m1); + vst1q_f16(dst_ptr + 2 * dst_step, m2); + vst1q_f16(dst_ptr + 3 * dst_step, m3); + } +} + +#ifdef ENABLE_ARM64 +void InputTransform4x4Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile, + int src_point_stride) { + LOAD_LINE_DATA_FP16(0); + LOAD_LINE_DATA_FP16(1); + LOAD_LINE_DATA_FP16(2); + LOAD_LINE_DATA_FP16(3); + + float16x8_t m0 = vsubq_f16(s00, s20); + float16x8_t m1 = vsubq_f16(s01, s21); + vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(s10, s20); + m1 = vaddq_f16(s11, s21); + vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + + m0 = vsubq_f16(s20, s10); + m1 = vsubq_f16(s21, s11); + vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + + m0 = vsubq_f16(s30, s10); + m1 = vsubq_f16(s31, s11); + vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); +} + +void InputTransform4x4Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 16; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; + for (int l = 0; l < 4; ++l) { + float16_t *src_ptr = src_data + l * C8NUM * block_tile; + TRANSPOSE_16x8; + } + + for (int c = 0; c < real_c; ++c) { + float16_t *src_ptr = src_data + c * block_tile; + float16_t *dst_ptr = dst_data + c * block_tile; + InputTransform4x4Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +} +#endif + void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { int j = 0; if (real_c == 8) { @@ -272,6 +391,95 @@ void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, i } } +void InputTransform6x6StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step) { + for (int l = 0; l < 6; ++l) { + const float16_t *src_ptr = src_data + l * 6 * src_step; + float16_t *dst_ptr = dst_data + l * dst_row_step; + + float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step); + float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step); + float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step); + float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step); + float16x8_t s4 = vld1q_f16(src_ptr + 4 * src_step); + float16x8_t s5 = vld1q_f16(src_ptr + 5 * src_step); + + float16x8_t tmp1 = vsubq_f16(s3, s1); + float16x8_t tmp2 = vsubq_f16(s4, s2); + float16x8_t m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s0, 4), vmulq_n_f16(s2, 5)), s4); + float16x8_t m1 = vaddq_f16(vmulq_n_f16(vaddq_f16(s1, s2), -4), vaddq_f16(s3, s4)); + float16x8_t m2 = vaddq_f16(vmulq_n_f16(vsubq_f16(s1, s2), 4), vsubq_f16(s4, s3)); + float16x8_t m3 = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); + float16x8_t m4 = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); + float16x8_t m5 = vaddq_f16(vsubq_f16(vmulq_n_f16(s1, 4), vmulq_n_f16(s3, 5)), s5); + + vst1q_f16(dst_ptr + 0 * dst_step, m0); + vst1q_f16(dst_ptr + 1 * dst_step, m1); + vst1q_f16(dst_ptr + 2 * dst_step, m2); + vst1q_f16(dst_ptr + 3 * dst_step, m3); + vst1q_f16(dst_ptr + 4 * dst_step, m4); + vst1q_f16(dst_ptr + 5 * dst_step, m5); + } +} + +#ifdef ENABLE_ARM64 +void InputTransform6x6Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile, + int src_point_stride) { + LOAD_LINE_DATA_FP16(0); + LOAD_LINE_DATA_FP16(1); + LOAD_LINE_DATA_FP16(2); + LOAD_LINE_DATA_FP16(3); + LOAD_LINE_DATA_FP16(4); + LOAD_LINE_DATA_FP16(5); + + float16x8_t m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s00, 4), vmulq_n_f16(s20, 5)), s40); + float16x8_t m1 = vaddq_f16(vsubq_f16(vmulq_n_f16(s01, 4), vmulq_n_f16(s21, 5)), s41); + vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vaddq_f16(s10, s20), -4), vaddq_f16(s30, s40)); + m1 = vaddq_f16(vmulq_n_f16(vaddq_f16(s11, s21), -4), vaddq_f16(s31, s41)); + vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s10, s20), 4), vsubq_f16(s40, s30)); + m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s11, s21), 4), vsubq_f16(s41, s31)); + vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s30, s10), 2), vsubq_f16(s40, s20)); + m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s31, s11), 2), vsubq_f16(s41, s21)); + vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s30, s10), -2), vsubq_f16(s40, s20)); + m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s31, s11), -2), vsubq_f16(s41, s21)); + vst1q_f16(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s10, 4), vmulq_n_f16(s30, 5)), s50); + m1 = vaddq_f16(vsubq_f16(vmulq_n_f16(s11, 4), vmulq_n_f16(s31, 5)), s51); + vst1q_f16(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); +} + +void InputTransform6x6Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 16; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; + for (int l = 0; l < 6; ++l) { + float16_t *src_ptr = src_data + l * C8NUM * block_tile; + TRANSPOSE_16x8; + } + + for (int c = 0; c < real_c; ++c) { + float16_t *src_ptr = src_data + c * block_tile; + float16_t *dst_ptr = dst_data + c * block_tile; + InputTransform6x6Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +} +#endif + void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { int j = 0; if (real_c == 8) { @@ -429,6 +637,133 @@ void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, i } } +void InputTransform8x8StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step) { + for (int l = 0; l < 8; ++l) { + const float16_t *src_ptr = src_data + l * 8 * src_step; + float16_t *dst_ptr = dst_data + l * dst_row_step; + + float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step); + float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step); + float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step); + float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step); + float16x8_t s4 = vld1q_f16(src_ptr + 4 * src_step); + float16x8_t s5 = vld1q_f16(src_ptr + 5 * src_step); + float16x8_t s6 = vld1q_f16(src_ptr + 6 * src_step); + float16x8_t s7 = vld1q_f16(src_ptr + 7 * src_step); + + float16x8_t m0 = + vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s0, 0.5625), vmulq_n_f16(s2, 3.0625)), vmulq_n_f16(s4, 3.5)), s6); + float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(s1, 1.125), vmulq_n_f16(s5, 0.5)); + float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(s2, 2.25), vmulq_n_f16(s4, 3.25)); + float16x8_t m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s3, 1.625)), s6); + float16x8_t m2 = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s3, 1.625)), s6); + tmp1 = vaddq_f16(vmulq_n_f16(s1, 0.5625), s5); + tmp2 = vsubq_f16(vmulq_n_f16(s2, 0.5625), vmulq_n_f16(s4, 2.5)); + float16x8_t m3 = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s3, 2.5)), s6); + float16x8_t m4 = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s3, 2.5)), s6); + tmp1 = vaddq_f16(vmulq_n_f16(s1, 0.375), vmulq_n_f16(s5, 1.5)); + tmp2 = vsubq_f16(vmulq_n_f16(s2, 0.25), vmulq_n_f16(s4, 1.25)); + float16x8_t m5 = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s3, 1.875)), s6); + float16x8_t m6 = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s3, 1.875)), s6); + float16x8_t m7 = + vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s1, -0.5625), vmulq_n_f16(s3, 3.0625)), vmulq_n_f16(s5, 3.5)), s7); + + vst1q_f16(dst_ptr + 0 * dst_step, m0); + vst1q_f16(dst_ptr + 1 * dst_step, m1); + vst1q_f16(dst_ptr + 2 * dst_step, m2); + vst1q_f16(dst_ptr + 3 * dst_step, m3); + vst1q_f16(dst_ptr + 4 * dst_step, m4); + vst1q_f16(dst_ptr + 5 * dst_step, m5); + vst1q_f16(dst_ptr + 6 * dst_step, m6); + vst1q_f16(dst_ptr + 7 * dst_step, m7); + } +} + +#ifdef ENABLE_ARM64 +void InputTransform8x8Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile, + int src_point_stride) { + LOAD_LINE_DATA_FP16(0); + LOAD_LINE_DATA_FP16(1); + LOAD_LINE_DATA_FP16(2); + LOAD_LINE_DATA_FP16(3); + LOAD_LINE_DATA_FP16(4); + LOAD_LINE_DATA_FP16(5); + LOAD_LINE_DATA_FP16(6); + LOAD_LINE_DATA_FP16(7); + + float16x8_t m0 = + vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s00, 0.5625), vmulq_n_f16(s20, 3.0625)), vmulq_n_f16(s40, 3.5)), s60); + float16x8_t m1 = + vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s01, 0.5625), vmulq_n_f16(s21, 3.0625)), vmulq_n_f16(s41, 3.5)), s61); + vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + + float16x8_t tmp10 = vaddq_f16(vmulq_n_f16(s10, 1.125), vmulq_n_f16(s50, 0.5)); + float16x8_t tmp11 = vaddq_f16(vmulq_n_f16(s11, 1.125), vmulq_n_f16(s51, 0.5)); + float16x8_t tmp20 = vsubq_f16(vmulq_n_f16(s20, 2.25), vmulq_n_f16(s40, 3.25)); + float16x8_t tmp21 = vsubq_f16(vmulq_n_f16(s21, 2.25), vmulq_n_f16(s41, 3.25)); + m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 1.625)), s60); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 1.625)), s61); + vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 1.625)), s60); + m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 1.625)), s61); + vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + + tmp10 = vaddq_f16(vmulq_n_f16(s10, 0.5625), s50); + tmp11 = vaddq_f16(vmulq_n_f16(s11, 0.5625), s51); + tmp20 = vsubq_f16(vmulq_n_f16(s20, 0.5625), vmulq_n_f16(s40, 2.5)); + tmp21 = vsubq_f16(vmulq_n_f16(s21, 0.5625), vmulq_n_f16(s41, 2.5)); + m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 2.5)), s60); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 2.5)), s61); + vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 2.5)), s60); + m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 2.5)), s61); + vst1q_f16(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + + tmp10 = vaddq_f16(vmulq_n_f16(s10, 0.375), vmulq_n_f16(s50, 1.5)); + tmp11 = vaddq_f16(vmulq_n_f16(s11, 0.375), vmulq_n_f16(s51, 1.5)); + tmp20 = vsubq_f16(vmulq_n_f16(s20, 0.25), vmulq_n_f16(s40, 1.25)); + tmp21 = vsubq_f16(vmulq_n_f16(s21, 0.25), vmulq_n_f16(s41, 1.25)); + m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 1.875)), s60); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 1.875)), s61); + vst1q_f16(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 1.875)), s60); + m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 1.875)), s61); + vst1q_f16(dst_ptr + 6 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 6 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s10, -0.5625), vmulq_n_f16(s30, 3.0625)), vmulq_n_f16(s50, 3.5)), s70); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s11, -0.5625), vmulq_n_f16(s31, 3.0625)), vmulq_n_f16(s51, 3.5)), s71); + vst1q_f16(dst_ptr + 7 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 7 * dst_step + 1 * pack_tile, m1); +} + +void InputTransform8x8Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 16; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; + for (int l = 0; l < 8; ++l) { + float16_t *src_ptr = src_data + l * C8NUM * block_tile; + TRANSPOSE_16x8; + } + + for (int c = 0; c < real_c; ++c) { + float16_t *src_ptr = src_data + c * block_tile; + float16_t *dst_ptr = dst_data + c * block_tile; + InputTransform8x8Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +} +#endif + OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type) { if (input_unit == 4 && output_unit < 4) { if (act_type == ActType_Relu) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_utils_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_utils_fp16.h index dfae3fb1182..345fee63c9e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_utils_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_utils_fp16.h @@ -29,9 +29,22 @@ extern "C" { typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); +typedef void (*InputTransStepFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +typedef void (*InputTransPackFp16Func)(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int real_c); + typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +typedef struct TransFp16FuncList { + InputTransFp16Func in_func_; + InputTransStepFp16Func in_step_func_; + InputTransPackFp16Func in_pack_func_; + OutputTransFp16Func out_func_; +} TransFp16FuncList; + #define Load16DataFp16 \ src[0] = vld1q_f16(src_data + 0 * src_step); \ src[1] = vld1q_f16(src_data + 1 * src_step); \ @@ -276,14 +289,77 @@ typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_da src[62] = vld1_f16(src_data + 62 * src_step); \ src[63] = vld1_f16(src_data + 63 * src_step); +#define LOAD_LINE_DATA_FP16(line) \ + float16x8_t s##line##0 = vld1q_f16(src_ptr + line * src_point_stride + 0 * pack_tile); \ + float16x8_t s##line##1 = vld1q_f16(src_ptr + line * src_point_stride + 1 * pack_tile); + +#define TRANSPOSE_16x8 \ + float16x8_t s0 = vld1q_f16(src_ptr + 0 * pack_tile); \ + float16x8_t s2 = vld1q_f16(src_ptr + 1 * pack_tile); \ + float16x8_t s4 = vld1q_f16(src_ptr + 2 * pack_tile); \ + float16x8_t s6 = vld1q_f16(src_ptr + 3 * pack_tile); \ + float16x8_t s8 = vld1q_f16(src_ptr + 4 * pack_tile); \ + float16x8_t s10 = vld1q_f16(src_ptr + 5 * pack_tile); \ + float16x8_t s12 = vld1q_f16(src_ptr + 6 * pack_tile); \ + float16x8_t s14 = vld1q_f16(src_ptr + 7 * pack_tile); \ + float16x8_t s1 = vld1q_f16(src_ptr + 8 * pack_tile); \ + float16x8_t s3 = vld1q_f16(src_ptr + 9 * pack_tile); \ + float16x8_t s5 = vld1q_f16(src_ptr + 10 * pack_tile); \ + float16x8_t s7 = vld1q_f16(src_ptr + 11 * pack_tile); \ + float16x8_t s9 = vld1q_f16(src_ptr + 12 * pack_tile); \ + float16x8_t s11 = vld1q_f16(src_ptr + 13 * pack_tile); \ + float16x8_t s13 = vld1q_f16(src_ptr + 14 * pack_tile); \ + float16x8_t s15 = vld1q_f16(src_ptr + 15 * pack_tile); \ + transpose8(&s0, &s2, &s4, &s6, &s8, &s10, &s12, &s14); \ + transpose8(&s1, &s3, &s5, &s7, &s9, &s11, &s13, &s15); \ + vst1q_f16(src_ptr + 0 * pack_tile, s0); \ + vst1q_f16(src_ptr + 1 * pack_tile, s1); \ + vst1q_f16(src_ptr + 2 * pack_tile, s2); \ + vst1q_f16(src_ptr + 3 * pack_tile, s3); \ + vst1q_f16(src_ptr + 4 * pack_tile, s4); \ + vst1q_f16(src_ptr + 5 * pack_tile, s5); \ + vst1q_f16(src_ptr + 6 * pack_tile, s6); \ + vst1q_f16(src_ptr + 7 * pack_tile, s7); \ + vst1q_f16(src_ptr + 8 * pack_tile, s8); \ + vst1q_f16(src_ptr + 9 * pack_tile, s9); \ + vst1q_f16(src_ptr + 10 * pack_tile, s10); \ + vst1q_f16(src_ptr + 11 * pack_tile, s11); \ + vst1q_f16(src_ptr + 12 * pack_tile, s12); \ + vst1q_f16(src_ptr + 13 * pack_tile, s13); \ + vst1q_f16(src_ptr + 14 * pack_tile, s14); \ + vst1q_f16(src_ptr + 15 * pack_tile, s15); + InputTransFp16Func GetInputTransFp16Func(int input_unit); +#ifdef ENABLE_ARM64 +InputTransStepFp16Func GetInputTransStepFp16Func(int input_unit); + +InputTransPackFp16Func GetInputTransPackFp16Func(int input_unit); +#endif + void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); +void InputTransform4x4StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +void InputTransform6x6StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +void InputTransform8x8StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +#ifdef ENABLE_ARM64 +void InputTransform4x4Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); +#endif + OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type); #define Store4DataFp16 \ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.c index 65d20b9742d..bf8205b61a7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.c @@ -23,11 +23,12 @@ // fp32 conv winograd void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data, TmpBufferAddress *buffer_list, int task_id, const ConvParameter *conv_param, - InputTransFunc in_func, OutputTransFunc out_func) { + TransFuncList trans_func) { if (conv_param->output_unit_ == 0) { return; } int in_channel = conv_param->input_channel_; + int input_unit = conv_param->input_unit_; int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); int output_count = out_w_block * out_h_block; @@ -35,26 +36,19 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const int output_tile_count = UP_DIV(output_count, tile_num); #ifdef ENABLE_AVX const int col_tile = C16NUM; - const int tmp_data_tile = C8NUM; + const int channel_pack_tile = C8NUM; #else const int col_tile = C8NUM; - const int tmp_data_tile = C4NUM; + const int channel_pack_tile = C4NUM; #endif int oc_tile = UP_DIV(conv_param->output_channel_, col_tile); int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); - int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_; - if (input_unit_square < conv_param->input_unit_) { - return; - } + int input_unit_square = input_unit * input_unit; - float *trans_input = buffer_list[0]; - float *gemm_out = buffer_list[1]; - float *tmp_data = buffer_list[2]; - float *col_buffer = buffer_list[3]; - int trans_input_offset = tile_num * input_unit_square * in_channel; - int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM; - int tmp_data_offset = input_unit_square * tmp_data_tile; - int col_buffer_offset = tile_num * in_channel; + float *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel; + float *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM; + float *tmp_data = buffer_list[2] + task_id * input_unit_square * channel_pack_tile; + float *col_buffer = buffer_list[3] + task_id * tile_num * in_channel; // step 1 : filter transform (pre-processed offline) // step 2 : input transform (online) for (int b = 0; b < conv_param->input_batch_; b++) { @@ -67,37 +61,75 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const if (cal_num <= 0) { return; } - WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset, - tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, - in_func); - // step 3 : gemm - float *src_ptr = trans_input + task_id * trans_input_offset; - float *dst_ptr = gemm_out + task_id * gemm_out_offset; - float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; - for (int i = 0; i < input_unit_square; ++i) { + +#ifdef ENABLE_ARM64 + // Optimize input transform. Only valid for arm64, the tile num is 12, the channel_tile is 4. + // For arm32, the tile_num is 4. + // For x86_sse, the tile_num is 4, the channel_tile is 4. + // For avx, the tile_num is 6, the channel_tile is 8. + // N = input_unit, M = tile_num + // The function(InputTransformNxNStep, InputTransform4x4PackM) needs to be rewritten. + bool fused_pack = + (cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL); + if (fused_pack) { + float *opt_trans_input = + buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, channel_pack_tile); + WinogradInputTransformOptStep(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_step_func_); + + for (int w_index = 0; w_index < input_unit; w_index++) { + float *src_w = opt_trans_input + w_index * input_unit * tile_num * channel_pack_tile; + for (int c = 0; c < UP_DIV(in_channel, channel_pack_tile); c++) { + int real_c = in_channel - c * channel_pack_tile; + real_c = real_c > channel_pack_tile ? channel_pack_tile : real_c; + float *src_c = src_w + c * input_unit_square * tile_num * channel_pack_tile; + float *dst_c = trans_input + c * tile_num * channel_pack_tile; + trans_func.in_pack_func_(src_c, dst_c, channel_pack_tile, in_channel * tile_num, real_c); + } + + for (int h_index = 0; h_index < input_unit; h_index++) { + const float *gemm_input = trans_input + h_index * tile_num * in_channel; + int point_index = h_index * input_unit + w_index; + const float *gemm_weight = trans_weight + point_index * in_channel * oc_tile * col_tile; + MatMulOpt(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num, + oc8 * C8NUM, input_unit_square, OutType_TileC8); + } + } + } else { +#endif + WinogradInputTransform(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_func_); + // step 3 : gemm + float *src_ptr = trans_input; + float *dst_ptr = gemm_out; + float *tmp_col_ptr = col_buffer; + for (int i = 0; i < input_unit_square; ++i) { #ifdef ENABLE_AVX - RowMajor2Col6Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); + RowMajor2Col6Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); #else RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); #endif - MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc_tile * col_tile, dst_ptr + i * C8NUM, NULL, 0, - in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2); + MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc_tile * col_tile, dst_ptr + i * C8NUM, NULL, 0, + in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2); + } +#ifdef ENABLE_ARM64 } +#endif // step 4 : output transform float *output_ptr = output_data + out_batch_offset; if (conv_param->out_format_ != NNACL_NC4HW4) { // nc4hw4 - WinogradOutputNHWCTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, - out_func); + WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); } else { #if defined(ENABLE_AVX) || defined(ENABLE_ARM64) - WinogradOutputNC4HW4Transform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, - out_func); + WinogradOutputNC4HW4Transform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); #else - WinogradOutputNHWCTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, - out_func); + WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); #endif } } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.h index 2b84d3e77ad..cedc503869e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.h @@ -36,7 +36,7 @@ extern "C" { // fp32 convolution winograd void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data, TmpBufferAddress *buffer_list, int task_id, const ConvParameter *conv_param, - InputTransFunc in_func, OutputTransFunc out_func); + TransFuncList trans_func); #ifdef __cplusplus } #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_transform.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_transform.c index 73c1b7acee6..1fd085eddb1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_transform.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_transform.c @@ -17,6 +17,59 @@ #include "nnacl/fp32/winograd_transform.h" #include "nnacl/op_base.h" +void PrepareTransInput(const float *src_data, float *dst_data, int interval_x_s, int interval_x_e, int interval_y_s, + int interval_y_e, int real_c, const ConvParameter *conv_param) { + int input_unit = conv_param->input_unit_; + int in_channel = conv_param->input_channel_; + int input_w = conv_param->input_w_; +#ifdef ENABLE_AVX + int channel_tile = C8NUM; +#else + int channel_tile = C4NUM; +#endif + // clear tmp buffer + if (interval_x_e - interval_x_s != input_unit || interval_y_e - interval_y_s != input_unit) { + memset(dst_data, 0, input_unit * input_unit * channel_tile * (int)(sizeof(float))); + } + + // get real input block with padding + if (real_c == channel_tile) { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * channel_tile + interval_x_s * channel_tile; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * channel_tile; + const float *src_addr = src_data + src_x_offset; + float *dst_addr = dst_data + dst_x_offset; +#ifdef ENABLE_AVX + MS_ST256_F32(dst_addr, MS_LD256_F32(src_addr)); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_STQ_F32(dst_addr, MS_LDQ_F32(src_addr)); +#else + for (int k = 0; k < channel_tile; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + } // interval x loop + } // interval y loop + } else { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * channel_tile + interval_x_s * channel_tile; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * channel_tile; + const float *src_addr = src_data + src_x_offset; + float *dst_addr = dst_data + dst_x_offset; + for (int k = 0; k < real_c; k++) { + dst_addr[k] = src_addr[k]; + } + } // interval x loop + } // interval y loop + } +} + // fp32 conv winograd void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num, int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, @@ -25,11 +78,11 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float * int output_unit = conv_param->output_unit_; int in_channel = conv_param->input_channel_; #ifdef ENABLE_AVX - int tile = C8NUM; + int channel_tile = C8NUM; #else - int tile = C4NUM; + int channel_tile = C4NUM; #endif - int ic4 = UP_DIV(in_channel, tile); + int ic4 = UP_DIV(in_channel, channel_tile); int pad_h = conv_param->pad_u_; int pad_w = conv_param->pad_l_; int input_h = conv_param->input_h_; @@ -49,54 +102,61 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float * int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); int dst_plane_offset = c * in_channel; for (int ic = 0; ic < ic4; ic++) { - // clear tmp buffer - memset(tmp_data, 0, input_unit * input_unit * tile * (int)(sizeof(float))); + int real_c = in_channel - ic * channel_tile; + real_c = real_c > channel_tile ? channel_tile : real_c; + const float *src_data = input_data + src_plane_offset + ic * channel_tile; + PrepareTransInput(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, conv_param); - int real_c = in_channel - ic * tile; - real_c = real_c > tile ? tile : real_c; - int src_ic4_offset = src_plane_offset + ic * tile; - // get real input block with padding - if (real_c == tile) { - for (int interval = interval_y_s; interval < interval_y_e; interval++) { - int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * in_channel; - int dst_y_offset = interval * input_unit * tile + interval_x_s * tile; - for (int j = 0; j < (interval_x_e - interval_x_s); j++) { - int src_x_offset = src_y_offset + j * in_channel; - int dst_x_offset = dst_y_offset + j * tile; - float *src_addr = (float *)(input_data) + src_x_offset; - float *dst_addr = tmp_data + dst_x_offset; -#ifdef ENABLE_AVX - MS_ST256_F32(dst_addr, MS_LD256_F32(src_addr)); -#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) - MS_STQ_F32(dst_addr, MS_LDQ_F32(src_addr)); -#else - for (int k = 0; k < tile; k++) { - dst_addr[k] = src_addr[k]; - } -#endif - } // interval x loop - } // interval y loop - } else { - for (int interval = interval_y_s; interval < interval_y_e; interval++) { - int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * in_channel; - int dst_y_offset = interval * input_unit * tile + interval_x_s * tile; - for (int j = 0; j < (interval_x_e - interval_x_s); j++) { - int src_x_offset = src_y_offset + j * in_channel; - int dst_x_offset = dst_y_offset + j * tile; - float *src_addr = (float *)(input_data) + src_x_offset; - float *dst_addr = tmp_data + dst_x_offset; - for (int k = 0; k < real_c; k++) { - dst_addr[k] = src_addr[k]; - } - } // interval x loop - } // interval y loop - } // input transform const int tile_num = C12NUM; - int dst_ic4_offset = dst_plane_offset + ic * tile; + int dst_ic4_offset = dst_plane_offset + ic * channel_tile; int dst_step = tile_num * in_channel; float *trans_input_ptr = trans_input + dst_ic4_offset; - func(tmp_data, trans_input_ptr, tile, dst_step, real_c); + func(tmp_data, trans_input_ptr, channel_tile, dst_step, real_c); + } + out_tile_index++; + } // cal_tile_num loop +} + +// Only support arm64 +void WinogradInputTransformOptStep(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransStepFunc func) { + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; + int channel_tile = C4NUM; + int ic4 = UP_DIV(in_channel, channel_tile); + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + NNACL_CHECK_ZERO_RETURN(out_w_block_num); + + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); + + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * channel_tile; + for (int ic = 0; ic < ic4; ic++) { + int real_c = in_channel - ic * channel_tile; + real_c = real_c > channel_tile ? channel_tile : real_c; + const float *src_data = input_data + src_plane_offset + ic * channel_tile; + PrepareTransInput(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, conv_param); + + // input transform + const int block_tile = C12NUM; + int dst_ic8_offset = dst_plane_offset + ic * block_tile * input_unit * input_unit * channel_tile; + size_t dst_step = input_unit * block_tile * channel_tile; + float *trans_input_ptr = trans_input + dst_ic8_offset; + func(tmp_data, trans_input_ptr, channel_tile, dst_step, block_tile * channel_tile); } out_tile_index++; } // cal_tile_num loop diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_transform.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_transform.h index ac44f169fb4..ab9bf1161f4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_transform.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_transform.h @@ -32,6 +32,10 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float * int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, InputTransFunc func); +void WinogradInputTransformOptStep(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransStepFunc func); + void WinogradOutputNHWCTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num, int out_tile_index, int output_unit_num, const ConvParameter *conv_param, OutputTransFunc func); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_utils.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_utils.c index 6df3a4e94cc..c1220996dbb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_utils.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_utils.c @@ -20,6 +20,19 @@ #include "nnacl/base/conv_common_base.h" #include "nnacl/errorcode.h" +#ifdef ENABLE_ARM64 +void transpose4(MS_FLOAT32X4 *s0, MS_FLOAT32X4 *s1, MS_FLOAT32X4 *s2, MS_FLOAT32X4 *s3) { + float64x2_t m0 = (float64x2_t)(vtrn1q_f32(*s0, *s1)); + float64x2_t m1 = (float64x2_t)(vtrn2q_f32(*s0, *s1)); + float64x2_t m2 = (float64x2_t)(vtrn1q_f32(*s2, *s3)); + float64x2_t m3 = (float64x2_t)(vtrn2q_f32(*s2, *s3)); + *s0 = (float32x4_t)(vtrn1q_f64(m0, m2)); + *s2 = (float32x4_t)(vtrn2q_f64(m0, m2)); + *s1 = (float32x4_t)(vtrn1q_f64(m1, m3)); + *s3 = (float32x4_t)(vtrn2q_f64(m1, m3)); +} +#endif + #ifdef ENABLE_AVX static InputTransFunc InputTransFuncList[] = { NULL, NULL, NULL, NULL, InputTransform4x4AvxUnit, NULL, InputTransform6x6AvxUnit, NULL, InputTransform8x8AvxUnit}; @@ -55,6 +68,18 @@ static OutputTransFunc OutputTransFuncList[] = { InputTransFunc GetInputTransFunc(int input_unit) { return InputTransFuncList[input_unit]; } +#ifdef ENABLE_ARM64 +static InputTransStepFunc InputTransStepFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4Step, NULL, InputTransform6x6Step, NULL, InputTransform8x8Step}; + +static InputTransPackFunc InputTransPackFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4Pack12, NULL, InputTransform6x6Pack12, NULL, InputTransform8x8Pack12}; + +InputTransStepFunc GetInputTransStepFunc(int input_unit) { return InputTransStepFuncList[input_unit]; } + +InputTransPackFunc GetInputTransPackFunc(int input_unit) { return InputTransPackFuncList[input_unit]; } +#endif + void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { #if defined(ENABLE_ARM) || defined(ENABLE_SSE) if (real_c == 4) { @@ -138,6 +163,136 @@ void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, #endif } +void InputTransform4x4Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) { +#ifdef ENABLE_ARM64 + for (int l = 0; l < 4; ++l) { + const float *src_ptr = src_data + l * 4 * src_step; + float *dst_ptr = dst_data + l * dst_row_step; + + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step); + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step); + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step); + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step); + MS_FLOAT32X4 m0 = MS_SUBQ_F32(s0, s2); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(s1, s2); + MS_FLOAT32X4 m2 = MS_SUBQ_F32(s2, s1); + MS_FLOAT32X4 m3 = MS_SUBQ_F32(s3, s1); + + MS_STQ_F32(dst_ptr + 0 * dst_step, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step, m2); + MS_STQ_F32(dst_ptr + 3 * dst_step, m3); + } +#else + float src[4]; + float m[4]; + for (int i = 0; i < C4NUM; ++i) { + for (int l = 0; l < 4; ++l) { + for (int w = 0; w < 4; ++w) { + int tmp_index = l * 4 + w; + src[w] = src_data[i + tmp_index * src_step]; + } + m[0] = src[0] - src[2]; + m[1] = src[1] + src[2]; + m[2] = src[2] - src[1]; + m[3] = src[3] - src[1]; + + float *dst = dst_data + l * dst_row_step; + for (int w = 0; w < 4; ++w) { + dst[i + w * dst_step] = m[w]; + } + } + } +#endif +} + +#ifdef ENABLE_ARM64 +void InputTransform4x4Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) { + LOAD_LINE_DATA(0); + LOAD_LINE_DATA(1); + LOAD_LINE_DATA(2); + LOAD_LINE_DATA(3); + + MS_FLOAT32X4 m0 = MS_SUBQ_F32(s00, s20); + MS_FLOAT32X4 m1 = MS_SUBQ_F32(s01, s21); + MS_FLOAT32X4 m2 = MS_SUBQ_F32(s02, s22); + MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(s10, s20); + m1 = MS_ADDQ_F32(s11, s21); + m2 = MS_ADDQ_F32(s12, s22); + MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2); + + m0 = MS_SUBQ_F32(s20, s10); + m1 = MS_SUBQ_F32(s21, s11); + m2 = MS_SUBQ_F32(s22, s12); + MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2); + + m0 = MS_SUBQ_F32(s30, s10); + m1 = MS_SUBQ_F32(s31, s11); + m2 = MS_SUBQ_F32(s32, s12); + MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2); +} +#endif + +void InputTransform4x4Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 12; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; +#ifdef ENABLE_ARM64 + for (int l = 0; l < 4; ++l) { + float *src_ptr = src_data + l * C4NUM * block_tile; + TRANSPOSE_12x4; + } + + for (int c = 0; c < real_c; ++c) { + float *src_ptr = src_data + c * block_tile; + float *dst_ptr = dst_data + c * block_tile; + InputTransform4x4Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +#else + for (int l = 0; l < 4; ++l) { + float *src = src_data + l * pack_tile * block_tile; + // 12 * 4 -> 4 * 12 + float tmp_mat[pack_tile][block_tile]; + for (int i = 0; i < block_tile; ++i) { + for (int j = 0; j < pack_tile; ++j) { + tmp_mat[j][i] = src[i * pack_tile + j]; + } + } + memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float)); + } + + float src[4]; + float m[4]; + for (int c = 0; c < real_c; ++c) { + for (int i = 0; i < block_tile; ++i) { + int tmp_index = c * block_tile + i; + for (int w = 0; w < 4; ++w) { + src[w] = src_data[tmp_index + w * src_point_stride]; + } + + m[0] = src[0] - src[2]; + m[1] = src[1] + src[2]; + m[2] = src[2] - src[1]; + m[3] = src[3] - src[1]; + + for (int w = 0; w < 4; ++w) { + dst_data[tmp_index + w * dst_step] = m[w]; + } + } + } +#endif +} + void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { #if defined(ENABLE_ARM) || defined(ENABLE_SSE) if (real_c == 4) { @@ -217,6 +372,169 @@ void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, #endif } +void InputTransform6x6Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) { +#ifdef ENABLE_ARM64 + for (int l = 0; l < 6; ++l) { + const float *src_ptr = src_data + l * 6 * src_step; + float *dst_ptr = dst_data + l * dst_row_step; + + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step); + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step); + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step); + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step); + MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 4 * src_step); + MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 5 * src_step); + + MS_FLOAT32X4 tmp1 = MS_SUBQ_F32(s3, s1); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(s4, s2); + MS_FLOAT32X4 m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s0, 4), MS_MULQ_N_F32(s2, 5)), s4); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s1, s2), -4), MS_ADDQ_F32(s3, s4)); + MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s1, s2), 4), MS_SUBQ_F32(s4, s3)); + MS_FLOAT32X4 m3 = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 2), tmp2); + MS_FLOAT32X4 m4 = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, -2), tmp2); + MS_FLOAT32X4 m5 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s1, 4), MS_MULQ_N_F32(s3, 5)), s5); + + MS_STQ_F32(dst_ptr + 0 * dst_step, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step, m2); + MS_STQ_F32(dst_ptr + 3 * dst_step, m3); + MS_STQ_F32(dst_ptr + 4 * dst_step, m4); + MS_STQ_F32(dst_ptr + 5 * dst_step, m5); + } +#else + float src[6]; + float m[6]; + for (int i = 0; i < C4NUM; ++i) { + for (int l = 0; l < 6; ++l) { + for (int w = 0; w < 6; ++w) { + int tmp_index = l * 6 + w; + src[w] = src_data[i + tmp_index * src_step]; + } + float tmp1 = src[3] - src[1]; + float tmp2 = src[4] - src[2]; + m[0] = 4 * src[0] - 5 * src[2] + src[4]; + m[1] = -4 * (src[1] + src[2]) + (src[3] + src[4]); + m[2] = 4 * (src[1] - src[2]) + (src[4] - src[3]); + m[3] = 2 * tmp1 + tmp2; + m[4] = -2 * tmp1 + tmp2; + m[5] = 4 * src[1] - 5 * src[3] + src[5]; + + float *dst = dst_data + l * dst_row_step; + for (int w = 0; w < 6; ++w) { + dst[i + w * dst_step] = m[w]; + } + } + } +#endif +} + +#ifdef ENABLE_ARM64 +void InputTransform6x6Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) { + LOAD_LINE_DATA(0); + LOAD_LINE_DATA(1); + LOAD_LINE_DATA(2); + LOAD_LINE_DATA(3); + LOAD_LINE_DATA(4); + LOAD_LINE_DATA(5); + + MS_FLOAT32X4 m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s00, 4), MS_MULQ_N_F32(s20, 5)), s40); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s01, 4), MS_MULQ_N_F32(s21, 5)), s41); + MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s02, 4), MS_MULQ_N_F32(s22, 5)), s42); + MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s10, s20), -4), MS_ADDQ_F32(s30, s40)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s11, s21), -4), MS_ADDQ_F32(s31, s41)); + m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s12, s22), -4), MS_ADDQ_F32(s32, s42)); + MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s10, s20), 4), MS_SUBQ_F32(s40, s30)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s11, s21), 4), MS_SUBQ_F32(s41, s31)); + m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s12, s22), 4), MS_SUBQ_F32(s42, s32)); + MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s30, s10), 2), MS_SUBQ_F32(s40, s20)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s31, s11), 2), MS_SUBQ_F32(s41, s21)); + m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s32, s12), 2), MS_SUBQ_F32(s42, s22)); + MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s30, s10), -2), MS_SUBQ_F32(s40, s20)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s31, s11), -2), MS_SUBQ_F32(s41, s21)); + m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s32, s12), -2), MS_SUBQ_F32(s42, s22)); + MS_STQ_F32(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 4 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s10, 4), MS_MULQ_N_F32(s30, 5)), s50); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s11, 4), MS_MULQ_N_F32(s31, 5)), s51); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s12, 4), MS_MULQ_N_F32(s32, 5)), s52); + MS_STQ_F32(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 5 * dst_step + 2 * pack_tile, m2); +} +#endif + +void InputTransform6x6Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 12; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; +#ifdef ENABLE_ARM64 + for (int l = 0; l < 6; ++l) { + float *src_ptr = src_data + l * C4NUM * block_tile; + TRANSPOSE_12x4; + } + + for (int c = 0; c < real_c; ++c) { + float *src_ptr = src_data + c * block_tile; + float *dst_ptr = dst_data + c * block_tile; + InputTransform6x6Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +#else + for (int l = 0; l < 6; ++l) { + float *src = src_data + l * pack_tile * block_tile; + // 12 * 4 -> 4 * 12 + float tmp_mat[pack_tile][block_tile]; + for (int i = 0; i < block_tile; ++i) { + for (int j = 0; j < pack_tile; ++j) { + tmp_mat[j][i] = src[i * pack_tile + j]; + } + } + memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float)); + } + + float src[6]; + float m[6]; + for (int c = 0; c < real_c; ++c) { + for (int i = 0; i < block_tile; ++i) { + int tmp_index = c * block_tile + i; + for (int w = 0; w < 6; ++w) { + src[w] = src_data[tmp_index + w * src_point_stride]; + } + + float tmp1 = src[3] - src[1]; + float tmp2 = src[4] - src[2]; + m[0] = 4 * src[0] - 5 * src[2] + src[4]; + m[1] = -4 * (src[1] + src[2]) + (src[3] + src[4]); + m[2] = 4 * (src[1] - src[2]) + (src[4] - src[3]); + m[3] = 2 * tmp1 + tmp2; + m[4] = -2 * tmp1 + tmp2; + m[5] = 4 * src[1] - 5 * src[3] + src[5]; + + for (int w = 0; w < 6; ++w) { + dst_data[tmp_index + w * dst_step] = m[w]; + } + } + } +#endif +} + #if defined(ENABLE_ARM) || defined(ENABLE_SSE) void InputTransform8x8Unit_block4(const float *src_data, float *dst_data, int src_step, int dst_step) { MS_FLOAT32X4 src[64]; @@ -334,6 +652,232 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, #endif } +void InputTransform8x8Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) { +#ifdef ENABLE_ARM64 + for (int l = 0; l < 8; ++l) { + const float *src_ptr = src_data + l * 8 * src_step; + float *dst_ptr = dst_data + l * dst_row_step; + + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step); + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step); + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step); + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step); + MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 4 * src_step); + MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 5 * src_step); + MS_FLOAT32X4 s6 = MS_LDQ_F32(src_ptr + 6 * src_step); + MS_FLOAT32X4 s7 = MS_LDQ_F32(src_ptr + 7 * src_step); + + MS_FLOAT32X4 m0 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s0, 0.5625), MS_MULQ_N_F32(s2, 3.0625)), MS_MULQ_N_F32(s4, 3.5)), s6); + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 1.125), MS_MULQ_N_F32(s5, 0.5)); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 2.25), MS_MULQ_N_F32(s4, 3.25)); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 1.625)), s6); + MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 1.625)), s6); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 0.5625), s5); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 0.5625), MS_MULQ_N_F32(s4, 2.5)); + MS_FLOAT32X4 m3 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 2.5)), s6); + MS_FLOAT32X4 m4 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 2.5)), s6); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 0.375), MS_MULQ_N_F32(s5, 1.5)); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 0.25), MS_MULQ_N_F32(s4, 1.25)); + MS_FLOAT32X4 m5 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 1.875)), s6); + MS_FLOAT32X4 m6 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 1.875)), s6); + MS_FLOAT32X4 m7 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s1, -0.5625), MS_MULQ_N_F32(s3, 3.0625)), MS_MULQ_N_F32(s5, 3.5)), s7); + + MS_STQ_F32(dst_ptr + 0 * dst_step, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step, m2); + MS_STQ_F32(dst_ptr + 3 * dst_step, m3); + MS_STQ_F32(dst_ptr + 4 * dst_step, m4); + MS_STQ_F32(dst_ptr + 5 * dst_step, m5); + MS_STQ_F32(dst_ptr + 6 * dst_step, m6); + MS_STQ_F32(dst_ptr + 7 * dst_step, m7); + } +#else + float src[8]; + float m[8]; + for (int i = 0; i < C4NUM; ++i) { + for (int l = 0; l < 8; ++l) { + for (int w = 0; w < 8; ++w) { + int tmp_index = l * 8 + w; + src[w] = src_data[i + tmp_index * src_step]; + } + m[0] = 0.5625f * src[0] - 3.0625f * src[2] + 3.5f * src[4] - src[6]; + float tmp1 = 1.125f * src[1] + 0.5f * src[5]; + float tmp2 = 2.25f * src[2] - 3.25f * src[4]; + m[1] = tmp1 + tmp2 - 1.625f * src[3] + src[6]; + m[2] = tmp2 - tmp1 + 1.625f * src[3] + src[6]; + tmp1 = 0.5625f * src[1] + src[5]; + tmp2 = 0.5625f * src[2] - 2.5f * src[4]; + m[3] = tmp1 + tmp2 - 2.5f * src[3] + src[6]; + m[4] = tmp2 - tmp1 + 2.5f * src[3] + src[6]; + tmp1 = 0.375f * src[1] + 1.5f * src[5]; + tmp2 = 0.25f * src[2] - 1.25f * src[4]; + m[5] = tmp1 + tmp2 - 1.875f * src[3] + src[6]; + m[6] = tmp2 - tmp1 + 1.875f * src[3] + src[6]; + m[7] = -0.5625f * src[1] + 3.0625f * src[3] - 3.5f * src[5] + src[7]; + + float *dst = dst_data + l * dst_row_step; + for (int w = 0; w < 8; ++w) { + dst[i + w * dst_step] = m[w]; + } + } + } +#endif +} + +#ifdef ENABLE_ARM64 +void InputTransform8x8Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) { + LOAD_LINE_DATA(0); + LOAD_LINE_DATA(1); + LOAD_LINE_DATA(2); + LOAD_LINE_DATA(3); + LOAD_LINE_DATA(4); + LOAD_LINE_DATA(5); + LOAD_LINE_DATA(6); + LOAD_LINE_DATA(7); + + MS_FLOAT32X4 m0 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s00, 0.5625), MS_MULQ_N_F32(s20, 3.0625)), MS_MULQ_N_F32(s40, 3.5)), s60); + MS_FLOAT32X4 m1 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s01, 0.5625), MS_MULQ_N_F32(s21, 3.0625)), MS_MULQ_N_F32(s41, 3.5)), s61); + MS_FLOAT32X4 m2 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s02, 0.5625), MS_MULQ_N_F32(s22, 3.0625)), MS_MULQ_N_F32(s42, 3.5)), s62); + MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2); + + MS_FLOAT32X4 tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 1.125), MS_MULQ_N_F32(s50, 0.5)); + MS_FLOAT32X4 tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 1.125), MS_MULQ_N_F32(s51, 0.5)); + MS_FLOAT32X4 tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 1.125), MS_MULQ_N_F32(s52, 0.5)); + MS_FLOAT32X4 tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 2.25), MS_MULQ_N_F32(s40, 3.25)); + MS_FLOAT32X4 tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 2.25), MS_MULQ_N_F32(s41, 3.25)); + MS_FLOAT32X4 tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 2.25), MS_MULQ_N_F32(s42, 3.25)); + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 1.625)), s60); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 1.625)), s61); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 1.625)), s62); + MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 1.625)), s60); + m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 1.625)), s61); + m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 1.625)), s62); + MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2); + + tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 0.5625), s50); + tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 0.5625), s51); + tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 0.5625), s52); + tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 0.5625), MS_MULQ_N_F32(s40, 2.5)); + tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 0.5625), MS_MULQ_N_F32(s41, 2.5)); + tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 0.5625), MS_MULQ_N_F32(s42, 2.5)); + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 2.5)), s60); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 2.5)), s61); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 2.5)), s62); + MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 2.5)), s60); + m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 2.5)), s61); + m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 2.5)), s62); + MS_STQ_F32(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 4 * dst_step + 2 * pack_tile, m2); + + tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 0.375), MS_MULQ_N_F32(s50, 1.5)); + tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 0.375), MS_MULQ_N_F32(s51, 1.5)); + tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 0.375), MS_MULQ_N_F32(s52, 1.5)); + tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 0.25), MS_MULQ_N_F32(s40, 1.25)); + tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 0.25), MS_MULQ_N_F32(s41, 1.25)); + tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 0.25), MS_MULQ_N_F32(s42, 1.25)); + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 1.875)), s60); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 1.875)), s61); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 1.875)), s62); + MS_STQ_F32(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 5 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 1.875)), s60); + m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 1.875)), s61); + m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 1.875)), s62); + MS_STQ_F32(dst_ptr + 6 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 6 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 6 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s10, -0.5625), MS_MULQ_N_F32(s30, 3.0625)), MS_MULQ_N_F32(s50, 3.5)), s70); + m1 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s11, -0.5625), MS_MULQ_N_F32(s31, 3.0625)), MS_MULQ_N_F32(s51, 3.5)), s71); + m2 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s12, -0.5625), MS_MULQ_N_F32(s32, 3.0625)), MS_MULQ_N_F32(s52, 3.5)), s72); + MS_STQ_F32(dst_ptr + 7 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 7 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 7 * dst_step + 2 * pack_tile, m2); +} +#endif + +void InputTransform8x8Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 12; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; +#ifdef ENABLE_ARM64 + for (int l = 0; l < 8; ++l) { + float *src_ptr = src_data + l * C4NUM * block_tile; + TRANSPOSE_12x4; + } + + for (int c = 0; c < real_c; ++c) { + float *src_ptr = src_data + c * block_tile; + float *dst_ptr = dst_data + c * block_tile; + InputTransform8x8Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +#else + for (int l = 0; l < 8; ++l) { + float *src = src_data + l * pack_tile * block_tile; + // 12 * 4 -> 4 * 12 + float tmp_mat[pack_tile][block_tile]; + for (int i = 0; i < block_tile; ++i) { + for (int j = 0; j < pack_tile; ++j) { + tmp_mat[j][i] = src[i * pack_tile + j]; + } + } + memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float)); + } + + float src[8]; + float m[8]; + for (int c = 0; c < real_c; ++c) { + for (int i = 0; i < block_tile; ++i) { + int tmp_index = c * block_tile + i; + for (int w = 0; w < 8; ++w) { + src[w] = src_data[tmp_index + w * src_point_stride]; + } + m[0] = 0.5625f * src[0] - 3.0625f * src[2] + 3.5f * src[4] - src[6]; + float tmp1 = 1.125f * src[1] + 0.5f * src[5]; + float tmp2 = 2.25f * src[2] - 3.25f * src[4]; + m[1] = tmp1 + tmp2 - 1.625f * src[3] + src[6]; + m[2] = tmp2 - tmp1 + 1.625f * src[3] + src[6]; + tmp1 = 0.5625f * src[1] + src[5]; + tmp2 = 0.5625f * src[2] - 2.5f * src[4]; + m[3] = tmp1 + tmp2 - 2.5f * src[3] + src[6]; + m[4] = tmp2 - tmp1 + 2.5f * src[3] + src[6]; + tmp1 = 0.375f * src[1] + 1.5f * src[5]; + tmp2 = 0.25f * src[2] - 1.25f * src[4]; + m[5] = tmp1 + tmp2 - 1.875f * src[3] + src[6]; + m[6] = tmp2 - tmp1 + 1.875f * src[3] + src[6]; + m[7] = -0.5625f * src[1] + 3.0625f * src[3] - 3.5f * src[5] + src[7]; + + for (int w = 0; w < 8; ++w) { + dst_data[tmp_index + w * dst_step] = m[w]; + } + } + } +#endif +} + OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type) { if (!CheckWinogradInputOutputUnit(input_unit, output_unit)) { return NULL; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_utils.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_utils.h index 539ba6a42df..e2ea3b067f7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_utils.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_utils.h @@ -28,9 +28,21 @@ extern "C" { #endif typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); +typedef void (*InputTransStepFunc)(const float *src_data, float *dst_data, int src_step, int dst_step, + int dst_row_step); + +typedef void (*InputTransPackFunc)(float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +typedef struct TransFuncList { + InputTransFunc in_func_; + InputTransStepFunc in_step_func_; + InputTransPackFunc in_pack_func_; + OutputTransFunc out_func_; +} TransFuncList; + #define Load16Data \ src[0] = MS_LDQ_F32(src_data + 0 * src_step); \ src[1] = MS_LDQ_F32(src_data + 1 * src_step); \ @@ -153,14 +165,66 @@ typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const fl src[62] = MS_LDQ_F32(src_data + 62 * src_step); \ src[63] = MS_LDQ_F32(src_data + 63 * src_step); +#define LOAD_LINE_DATA(line) \ + MS_FLOAT32X4 s##line##0 = MS_LDQ_F32(src_ptr + line * src_point_stride + 0 * pack_tile); \ + MS_FLOAT32X4 s##line##1 = MS_LDQ_F32(src_ptr + line * src_point_stride + 1 * pack_tile); \ + MS_FLOAT32X4 s##line##2 = MS_LDQ_F32(src_ptr + line * src_point_stride + 2 * pack_tile); + +#define TRANSPOSE_12x4 \ + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * pack_tile); \ + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 1 * pack_tile); \ + MS_FLOAT32X4 s6 = MS_LDQ_F32(src_ptr + 2 * pack_tile); \ + MS_FLOAT32X4 s9 = MS_LDQ_F32(src_ptr + 3 * pack_tile); \ + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 4 * pack_tile); \ + MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 5 * pack_tile); \ + MS_FLOAT32X4 s7 = MS_LDQ_F32(src_ptr + 6 * pack_tile); \ + MS_FLOAT32X4 s10 = MS_LDQ_F32(src_ptr + 7 * pack_tile); \ + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 8 * pack_tile); \ + MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 9 * pack_tile); \ + MS_FLOAT32X4 s8 = MS_LDQ_F32(src_ptr + 10 * pack_tile); \ + MS_FLOAT32X4 s11 = MS_LDQ_F32(src_ptr + 11 * pack_tile); \ + transpose4(&s0, &s3, &s6, &s9); \ + transpose4(&s1, &s4, &s7, &s10); \ + transpose4(&s2, &s5, &s8, &s11); \ + MS_STQ_F32(src_ptr + 0 * pack_tile, s0); \ + MS_STQ_F32(src_ptr + 1 * pack_tile, s1); \ + MS_STQ_F32(src_ptr + 2 * pack_tile, s2); \ + MS_STQ_F32(src_ptr + 3 * pack_tile, s3); \ + MS_STQ_F32(src_ptr + 4 * pack_tile, s4); \ + MS_STQ_F32(src_ptr + 5 * pack_tile, s5); \ + MS_STQ_F32(src_ptr + 6 * pack_tile, s6); \ + MS_STQ_F32(src_ptr + 7 * pack_tile, s7); \ + MS_STQ_F32(src_ptr + 8 * pack_tile, s8); \ + MS_STQ_F32(src_ptr + 9 * pack_tile, s9); \ + MS_STQ_F32(src_ptr + 10 * pack_tile, s10); \ + MS_STQ_F32(src_ptr + 11 * pack_tile, s11); + InputTransFunc GetInputTransFunc(int input_unit); +#ifdef ENABLE_ARM64 +InputTransStepFunc GetInputTransStepFunc(int input_unit); + +InputTransPackFunc GetInputTransPackFunc(int input_unit); +#endif + void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); +void InputTransform4x4Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step); + +void InputTransform4x4Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); +void InputTransform6x6Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step); + +void InputTransform6x6Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); +void InputTransform8x8Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step); + +void InputTransform8x8Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type); #define Store4Data \ diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc b/mindspore/lite/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc index 9bdb1bd1712..4116d3d0811 100644 --- a/mindspore/lite/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc +++ b/mindspore/lite/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc @@ -172,10 +172,10 @@ int ConvolutionWinogradFP32Coder::InitWeightBias() { } int ConvolutionWinogradFP32Coder::ConfigInputOutput() { - in_func_ = GetInputTransFunc(input_unit_); - MS_CHECK_TRUE(!in_func_.empty(), "Get input_trans_func failed."); - out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_); - MS_CHECK_TRUE(!out_func_.empty(), "Get output_trans_func_ failed."); + trans_func_str_.in_func_ = GetInputTransFunc(input_unit_); + MS_CHECK_TRUE(!trans_func_str_.in_func_.empty(), "Get input_trans_func failed."); + trans_func_str_.out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_); + MS_CHECK_TRUE(!trans_func_str_.out_func_.empty(), "Get output_trans_func_ failed."); return RET_OK; } @@ -269,9 +269,10 @@ int ConvolutionWinogradFP32Coder::DoCode(CoderContext *const context) { << allocator_->GetRuntimeAddr(gemm_out_) << ", " << allocator_->GetRuntimeAddr(tmp_data_) << ", " << allocator_->GetRuntimeAddr(col_buffer_) << "};\n"; code.CodeStruct("conv_parameter", *conv_param_); + code.CodeStruct("trans_func", trans_func_str_); // code operator func code.CodeFunction("ConvWinogardFp32", input_tensor_, trans_weight_, new_bias_, output_tensor_, - "tmp_buffer_address_list", kDefaultTaskId, "&conv_parameter", in_func_, out_func_); + "tmp_buffer_address_list", kDefaultTaskId, "&conv_parameter", "trans_func"); context->AppendCode(code.str()); return RET_OK; } diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h b/mindspore/lite/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h index 1e236ae6221..8760adec3ec 100644 --- a/mindspore/lite/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h +++ b/mindspore/lite/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h @@ -22,6 +22,7 @@ #include #include "coder/opcoders/base/conv2d_base_coder.h" #include "nnacl/conv_parameter.h" +#include "wrapper/fp32/conv_winograd_fp32_wrapper.h" namespace mindspore::lite::micro::nnacl { class ConvolutionWinogradFP32Coder : public Conv2DBaseCoder { @@ -68,8 +69,7 @@ class ConvolutionWinogradFP32Coder : public Conv2DBaseCoder { float *gemm_out_{nullptr}; float *col_buffer_{nullptr}; - std::string in_func_; - std::string out_func_; + TransFuncStr trans_func_str_; }; } // namespace mindspore::lite::micro::nnacl #endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_FP32_CONVOLUTION_WINOGRAD_FP32_CODER_H_ diff --git a/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc b/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc index df53773ebfb..15d0a0f5b6e 100644 --- a/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc +++ b/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc @@ -157,4 +157,8 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const SpliceWrappe splice_param.src_to_dst_row_offset); } +void NNaclFp32Serializer::CodeStruct(const std::string &name, const TransFuncStr trans_func_str) { + CodeBaseStruct("TransFuncList", name, trans_func_str.in_func_, nullptr, nullptr, trans_func_str.out_func_); +} + } // namespace mindspore::lite::micro::nnacl diff --git a/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h b/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h index 0d1efe2951a..4c894dfb2b0 100644 --- a/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h +++ b/mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h @@ -36,6 +36,7 @@ #include "nnacl/fp32/strided_slice_fp32.h" #include "wrapper/fp32/arithmetic_fp32_wrapper.h" #include "wrapper/base/affine_wrapper.h" +#include "wrapper/fp32/conv_winograd_fp32_wrapper.h" namespace mindspore::lite::micro::nnacl { @@ -60,6 +61,7 @@ class NNaclFp32Serializer : public Serializer { void CodeStruct(const std::string &name, const StridedSliceParameter &strided_slice_parameter); void CodeStruct(const std::string &name, const ArithmeticWrapperInfo &arithmetic_wrapper_info); void CodeStruct(const std::string &name, const SpliceWrapperParam &splice_param); + void CodeStruct(const std::string &name, const TransFuncStr trans_func_str); }; } // namespace mindspore::lite::micro::nnacl diff --git a/mindspore/lite/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.h b/mindspore/lite/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.h new file mode 100644 index 00000000000..b5d6dae7ab6 --- /dev/null +++ b/mindspore/lite/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.h @@ -0,0 +1,30 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_CONV_WINOGRAD_FP32_WRAPPER_H_ +#define MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_CONV_WINOGRAD_FP32_WRAPPER_H_ +#include +#ifdef __cplusplus +extern "C" { +#endif +typedef struct TransFuncStr { + std::string in_func_; + std::string out_func_; +} TransFuncStr; + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_CONV_WINOGRAD_FP32_WRAPPER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc index 91728d2bcc8..861f263dfb0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc @@ -119,21 +119,40 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { return RET_ERROR; } + opt_input_trans_ = reinterpret_cast( + ctx_->allocator->Malloc(thread_count_ * row_tile_ * input_unit_ * input_unit_ * + UP_ROUND(conv_param_->input_channel_, C8NUM) * sizeof(float16_t))); + if (opt_input_trans_ == nullptr) { + MS_LOG(ERROR) << "malloc opt_input_trans_ failed."; + return RET_ERROR; + } + tmp_buffer_address_list_[0] = trans_input_; tmp_buffer_address_list_[1] = gemm_out_; tmp_buffer_address_list_[2] = tmp_data_; tmp_buffer_address_list_[3] = col_buffer_; + tmp_buffer_address_list_[4] = opt_input_trans_; return RET_OK; } int ConvolutionWinogradFP16CPUKernel::ConfigInputOutput() { - in_func_ = GetInputTransFp16Func(input_unit_); - if (in_func_ == nullptr) { + trans_func_.in_func_ = GetInputTransFp16Func(input_unit_); + if (trans_func_.in_func_ == nullptr) { MS_LOG(ERROR) << "in_func_ is null."; return RET_ERROR; } - out_func_ = GetOutputTransFp16Func(input_unit_, output_unit_, conv_param_->act_type_); - if (out_func_ == nullptr) { +#ifdef ENABLE_ARM64 + trans_func_.in_step_func_ = GetInputTransStepFp16Func(input_unit_); + if (trans_func_.in_step_func_ == nullptr) { + MS_LOG(DEBUG) << "in_step_func_ is null."; + } + trans_func_.in_pack_func_ = GetInputTransPackFp16Func(input_unit_); + if (trans_func_.in_pack_func_ == nullptr) { + MS_LOG(DEBUG) << "in_pack_func_ is null."; + } +#endif + trans_func_.out_func_ = GetOutputTransFp16Func(input_unit_, output_unit_, conv_param_->act_type_); + if (trans_func_.out_func_ == nullptr) { MS_LOG(ERROR) << "out_func_ is null."; return RET_ERROR; } @@ -219,7 +238,7 @@ int ConvolutionWinogradFP16CPUKernel::RunImpl(int task_id) { } ConvWinogardFp16(input_ptr, reinterpret_cast(packed_weight_), reinterpret_cast(bias_data_), output_ptr, tmp_buffer_address_list_, task_id, - conv_param_, in_func_, out_func_); + conv_param_, trans_func_); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h index d6e967c06b2..9601c683597 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h @@ -65,6 +65,10 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseCPUKernel { ctx_->allocator->Free(col_buffer_); col_buffer_ = nullptr; } + if (opt_input_trans_ != nullptr) { + ctx_->allocator->Free(opt_input_trans_); + opt_input_trans_ = nullptr; + } } int FilterWeight(); int kernel_unit_ = 0; @@ -74,11 +78,11 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseCPUKernel { float16_t *trans_input_ = nullptr; float16_t *gemm_out_ = nullptr; float16_t *col_buffer_ = nullptr; + float16_t *opt_input_trans_ = nullptr; float matrix_g_[64]; float matrix_gt_[64]; - TmpBufferAddressFp16 tmp_buffer_address_list_[4] = {0}; - InputTransFp16Func in_func_ = nullptr; - OutputTransFp16Func out_func_ = nullptr; + TmpBufferAddressFp16 tmp_buffer_address_list_[5] = {0}; + TransFp16FuncList trans_func_; int col_tile_ = 0; int row_tile_ = 0; }; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc index 340ab47d999..0124981326c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc @@ -69,21 +69,40 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { return RET_ERROR; } + opt_input_trans_ = reinterpret_cast( + ctx_->allocator->Malloc(thread_count_ * tile_num_ * input_unit_ * input_unit_ * + UP_ROUND(conv_param_->input_channel_, tmp_data_tile_) * sizeof(float))); + if (opt_input_trans_ == nullptr) { + MS_LOG(ERROR) << "malloc opt_input_trans_ failed."; + return RET_ERROR; + } + tmp_buffer_address_list_[0] = trans_input_; tmp_buffer_address_list_[1] = gemm_out_; tmp_buffer_address_list_[2] = tmp_data_; tmp_buffer_address_list_[3] = col_buffer_; + tmp_buffer_address_list_[4] = opt_input_trans_; return RET_OK; } int ConvolutionWinogradCPUKernel::ConfigInputOutput() { - in_func_ = GetInputTransFunc(input_unit_); - if (in_func_ == nullptr) { + trans_func_.in_func_ = GetInputTransFunc(input_unit_); + if (trans_func_.in_func_ == nullptr) { MS_LOG(ERROR) << "in_func_ is null."; return RET_ERROR; } - out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_); - if (out_func_ == nullptr) { +#ifdef ENABLE_ARM64 + trans_func_.in_step_func_ = GetInputTransStepFunc(input_unit_); + if (trans_func_.in_step_func_ == nullptr) { + MS_LOG(DEBUG) << "in_step_func_ is null."; + } + trans_func_.in_pack_func_ = GetInputTransPackFunc(input_unit_); + if (trans_func_.in_pack_func_ == nullptr) { + MS_LOG(DEBUG) << "in_pack_func_ is null."; + } +#endif + trans_func_.out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_); + if (trans_func_.out_func_ == nullptr) { MS_LOG(ERROR) << "out_func_ is null."; return RET_ERROR; } @@ -152,7 +171,7 @@ int ConvolutionWinogradCPUKernel::RunImpl(int task_id) { CHECK_NULL_RETURN(output_data); ConvWinogardFp32(ori_input_data, reinterpret_cast(packed_weight_), reinterpret_cast(bias_data_), output_data, tmp_buffer_address_list_, task_id, - conv_param_, in_func_, out_func_); + conv_param_, trans_func_); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h index 00e71b7682c..88edc4852c6 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h @@ -62,6 +62,10 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { ctx_->allocator->Free(col_buffer_); col_buffer_ = nullptr; } + if (opt_input_trans_ != nullptr) { + ctx_->allocator->Free(opt_input_trans_); + opt_input_trans_ = nullptr; + } } int kernel_unit_{0}; int input_unit_{0}; @@ -73,11 +77,11 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { float *trans_input_ = nullptr; float *gemm_out_ = nullptr; float *col_buffer_ = nullptr; + float *opt_input_trans_ = nullptr; float matrix_g_[64]; float matrix_gt_[64]; - TmpBufferAddress tmp_buffer_address_list_[4] = {nullptr}; - InputTransFunc in_func_ = nullptr; - OutputTransFunc out_func_ = nullptr; + TmpBufferAddress tmp_buffer_address_list_[5] = {nullptr}; + TransFuncList trans_func_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/test/config/models_caffe_fp16.cfg b/mindspore/lite/test/config/models_caffe_fp16.cfg index 22ebf6c3222..58031dd7bb7 100644 --- a/mindspore/lite/test/config/models_caffe_fp16.cfg +++ b/mindspore/lite/test/config/models_caffe_fp16.cfg @@ -3,13 +3,13 @@ # [second column]:accuracy limit for float16 in arm64 device hdc_age_medium 5.9 beard 2 -emotion 60 -gender_res_large_deploy 0.1 +emotion 216 +gender_res_large_deploy 2 glasses 4 hat 2.5 ml_bank_detect_0312_tmp 20 ml_face_div_parsing 8 -ml_hardware_eyeclose 0.1 +ml_hardware_eyeclose 0.5 ml_ocr_detect_20200305 10 Mnet6_0312_extract_pay 15 pose_3d 90 @@ -37,7 +37,7 @@ ml_ocr_sfz_add_final_0325 0.1 ml_hardware_pose 2 ml_bank_recog 0.1 2012_ATLANTA_10class_20190131_v4.0 12 -mnet 12 +mnet 13 recognition 10.8 ml_face_landmark 1 model_hebing_3branch 40 @@ -71,31 +71,31 @@ ml_location_scene_division 8 ml_tabel_recog 0.1 ml_text_division 12 # Further analysis in the future to model ml_video_edit_Mnet -ml_video_edit_Mnet 11.5 +ml_video_edit_Mnet 15.5 ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145 0.5 hdc_contour_pose_128 0.5 hdc_emotion 0.5 hdc_fivembnet 1 hdc_isface 0.5 hdc_mobilenetface 11.5 # small output causes big bias -hdc_retinaface 14 +hdc_retinaface 15 hdc_resnet 7 -ml_video_edit_detect_20211111 2.5 +ml_video_edit_detect_20211111 3 ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145_20210121 0.5 ml_video_edit_have_imageProcessLayer_interpTo145_20201015 0.5 ml_video_edit_MnetN367_extract_1010_pay 1 ml_video_edit_person_divison_pic 0.5 ml_video_edit_reid 1 -ml_video_edit_v10_best_model_nomean_20200723 5.1 +ml_video_edit_v10_best_model_nomean_20200723 6 ml_video_edit_img_segment 3 ml_video_edit_video_segment_gauss_adaptis_part1 5 # When the input range is [-1,1], the precision is poor, and the output value is very small (10e-5). If the input range is adjusted to [0,255], the precision will decrease to 15.5415%, and the rest is cumulative error. -ml_handpose 175 +ml_handpose 177 hdc_Face_Aesthetic_MTI_Aesthetic 0.5 ml_face_compare 8.7 ml_face_tracking 2.5 -ml_face_beard 0.6 -ml_face_age 3.7 +ml_face_beard 1 +ml_face_age 4 ml_face_pose 1 ml_face_isface 0.5 ml_face_glasses 3.4 @@ -108,13 +108,13 @@ ml_Hand_deploy 4 ml_hand_3d_detection 12 ml_hand_3d_regression 5.4 # ml_ARengine23_bodypose: The difference of output node divided by a very small value leads to a large error -ml_ARengine23_bodypose 56 +ml_ARengine23_bodypose 57 ml_ocr_bank_card_detection_inception_tmp 20 ml_ocr_bank_card_recognition_fcny 0.5 -hiai_cv_aestheticsEngineModel_osp 1.6 +hiai_cv_aestheticsEngineModel_osp 3.5 ml_face_hat 2.2 -bank_card_recognition_fcny 17 -bank_card_detection_inception_tmp 12 +bank_card_recognition_fcny 19 +bank_card_detection_inception_tmp 13.5 ml_ocr_identify_card_fcny 0.5 ml_ocr_identify_card_detect_tmp 2 identify_card_detect_tmp 0.5 @@ -123,18 +123,18 @@ ml_2012_ocr_rec_caffe 0.5 ml_lable_model_hebing_device 3 ml_face_sex 0.7 # ml_face_mnet: The precision problem caused by cumulative error. -ml_face_mnet 12 +ml_face_mnet 13 ml_segmentation_atlanta_1 0.5 bolt_deploy_color-server 0.5 ml_face_emotion 0.5 hdc_ocr_recog_horizontal 0.5 # The outputs of two Heatmap_depth models have small value -ml_Heatmap_depth_240180;2 10 +ml_Heatmap_depth_240180;2 14.5 ml_Heatmap_depth_180240;2 7 ml_video_edit_hair_dyeing_segmodel_v3 0.5 ml_video_edit_hairline_segmentation;3 1.5 ml_video_edit_seg_320 0.5 hiai_machine_vision_jfr_newmodel_2730_houduan_yolo 5 hiai_machine_vision_mobileNet101_nosoftce_mobilenet_resnet 7.5 -ml_video_edit_person_divison_video;2 38 +ml_video_edit_person_divison_video;2 42 ml_video_edit_hair_dyeing_segmodel_20211119 0.5 diff --git a/mindspore/lite/test/config/models_onnx_fp16.cfg b/mindspore/lite/test/config/models_onnx_fp16.cfg index 27114668649..399d1321208 100644 --- a/mindspore/lite/test/config/models_onnx_fp16.cfg +++ b/mindspore/lite/test/config/models_onnx_fp16.cfg @@ -3,7 +3,7 @@ # [second column]:accuracy limit for float16 in arm64 device mtk_detect-mbv2-shortcut-400-400-simplified.onnx 4 mtk_face_features_v3.onnx 20 -emotion-ferplus-8.onnx 1 +emotion-ferplus-8.onnx 1.5 #rcnn-ilsvrc13-9.onnx 0.1 efficientnet-lite4-11.onnx 2 mobilenetv2-7.onnx 8 @@ -27,7 +27,7 @@ mnist-8.onnx 10 crnn_lite_lstm_v2.onnx;1;32,32,32,1 0.3 #psenet_lite_mbv2.onnx;1;1,32,32,3 0.6 #occasionally aborted -super-resolution-10.onnx;1;1,224,224,1 4.5 +super-resolution-10.onnx;1;1,224,224,1 5 tinyyolov2-8.onnx;1;1,416,416,3 5.5 #ml_2012_ocr_cn.onnx -1 #ml_2012_ocr_cn_noLSTM.onnx 1 @@ -52,10 +52,10 @@ ml_video_edit_style_transfer_autoportrait.onnx 2 ml_video_edit_style_transfer_candy.onnx 2 ml_video_edit_style_transfer_gongnongbing.onnx 2 ml_video_edit_style_transfer_starry.onnx 2 -hdc_Face_Landmark5_MTI_Aesthetic.onnx 0.5 +hdc_Face_Landmark5_MTI_Aesthetic.onnx 1 hdc_Image_Aesthetic_MTI_Aesthetic.onnx 0.5 hdc_resnet_1w_class.onnx 6 -gts_text_detection.onnx;1;1,224,224,3 10 +gts_text_detection.onnx;1;1,224,224,3 11 hdc_Face_Emotion_MTI_Aesthetic.onnx 144 ml_video_edit_imitate_filter.onnx 120 ml_facedetector.onnx 6 @@ -71,7 +71,7 @@ mtk_emotions-d2012-75.onnx 6 mtk_detect-mbv1-shortcut-400-400.onnx 0.5 mtk_detect-mbv2-shortcut-400-400.onnx 0.5 mtk_detect_mbv1_640_480.onnx 0.5 -mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified_onnx.onnx 2 +mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified_onnx.onnx 2.5 mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified_onnx.onnx 6.5 mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified_onnx.onnx 2.5 mtk_detect_mbv1_640_480_nopostprocess_simplified_onnx.onnx;1;1,480,640,3 2 @@ -87,7 +87,7 @@ Q888_iris_detect.onnx 0.5 ssd_mobilenet_v1_10.onnx;1;1,383,640,3 0.5 # The output from a conv in the later part contains many minus values, the following leakyRelu makes them become very # close to 0 (-e^-4). The fp16 precision lost a lot in this case and it affects the following computation. -Harmony_Voiceprint.onnx;1;1,200,40,1 21.5 # small output causes big bias +Harmony_Voiceprint.onnx;1;1,200,40,1 68 # small output causes big bias # A matmul op in the later part produces overflowed output values (>65504). #ml_video_edit_art_generate_20210513.onnx nan # bn_fusion causes a big bias(maybe random), need to debug later: The original bias is 2.1 @@ -104,11 +104,11 @@ ml_video_edit_makeup_mobilenetv203.onnx 4 # The input of ml_video_edit_hair_dyeing_migrate_v2.onnx should be between [0, 1] ml_video_edit_hair_dyeing_migrate_v2.onnx;4 2.5 Q888_CV_face_recognition_self.onnx 3.9 -ml_video_edit_hair_dyeing_migrate_v2_fix.onnx;4 3 +ml_video_edit_hair_dyeing_migrate_v2_fix.onnx;4 3.5 ml_intelligent_cockpit_model.onnx;3;1,32:1,32:1,32 3.8 CloudBU_FSRCNN_RTC_8ch_3450_QP9.onnx;1;1,225,225,3 1.5 CloudBU_rfdn_rtc_x2_ver2_13.onnx;1;1,225,225,3 1.5 -CloudBU_rfdn_rtc_x2_ver2_3450.onnx;1;1,225,225,3 3.0 +CloudBU_rfdn_rtc_x2_ver2_3450.onnx;1;1,225,225,3 4 ml_motion_capture_nanodet_m_0.5x_people_0928_sim.onnx 8 ml_motion_capture_smpl_0916.onnx;3 ml_motion_capture_spin_mobile_mv3_v3_57mm_sim.onnx;5 18 @@ -116,7 +116,7 @@ ml_video_edit_dimming_tech_model_345000_color.onnx;2 2 Ireland_ulfgf.onnx;1;1,240,320,3 Ireland_gaze_corrector.onnx;3 15 Ireland_face_detector.onnx 2 -Ireland_gaze_estimator_ng.onnx 6 +Ireland_gaze_estimator_ng.onnx 8 carbu_intelligent_cockpit_fasttext_best.onnx 0.5 ml_video_edit_shot_selection_yolox_nano_coco_reduced.onnx 3 ml_video_edit_shot_selection_face_emotion.onnx 0.7 diff --git a/mindspore/lite/test/config/models_tf_fp16.cfg b/mindspore/lite/test/config/models_tf_fp16.cfg index 2259c6dcfcd..dc499b6b569 100644 --- a/mindspore/lite/test/config/models_tf_fp16.cfg +++ b/mindspore/lite/test/config/models_tf_fp16.cfg @@ -18,7 +18,7 @@ hiai_ssd_mobilenetv2_object.pb 15 hiai_humanDetection.pb 3.5 hiai_PoseEstimation_Pcm.pb 0.5 # The last layer has a very small value, which leads to a large error -hiai_cn_recognize_modify_padv2.pb;1;1,32,512,1 27 +hiai_cn_recognize_modify_padv2.pb;1;1,32,512,1 37 hiai_model_normalize_object_scene_ps_20200519.pb;1;1,224,224,3 17.1 # The output of mtk_model_ckpt.pb has small value mtk_model_ckpt.pb 19.5 @@ -33,14 +33,14 @@ mtk_face_features_v1.pb 26 model_normalize_object_scene_ps_20200519.pb;1;1,224,224,3 10 hiai_AADB_HADB_MBV2_model.pb;1;1,224,224,3 6 hiai_frozen_inference_graph.pb 12 -hiai_lm_inference_graph.pb 1.2 +hiai_lm_inference_graph.pb 1.5 hiai_ghostnet.pb 0.9 hiai_face_model_npu.pb 0.5 -hiai_cv_focusShootOCRModel_02.pb 10.5 +hiai_cv_focusShootOCRModel_02.pb 12.5 hiai_label_and_video.pb;1;1,224,224,3 23 hiai_dress_detect.pb;1;1,960,960,3 1.5 hiai_iMaxDN_RGB.pb 0.5 -hiai_iMaxSR_RGB.pb 3.5 +hiai_iMaxSR_RGB.pb 5 hiai_ctpn_feature_map.pb 6.5 hiai_cpu_face_gazing.pb 0.5 hiai_cpu_face_emotion.pb 2.2 @@ -49,7 +49,7 @@ Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid.pb 1.5 # The input of Q_crnn_ori_75w_slim model is between 0-255, but its outputs has small values (e-6). Q_crnn_ori_75w_slim_norm.pb 37 # The output of Q_crnn_ori_v2 model has small values (e-4). -Q_crnn_ori_v2_405001_notrans_nopre.pb 24 +Q_crnn_ori_v2_405001_notrans_nopre.pb 33 # The input of hiai_latin models are between 0-255 hiai_latin_ocr.pb 4 hiai_latin_ocr_1.pb 3.5 @@ -68,7 +68,7 @@ ml_vision_guide_detection2.pb;1;1,320,320,1 1 ml_tts_encoder.pb;4;1,44:1:1:1 9 # encoder_0111_control_flow.pb is same as ml_tts_encoder_control_flow.pb #encoder_0111_control_flow.pb;4;1:1,44:1:1 10 -ml_video_edit_video_segment_gauss_adaptis_part2.pb;2 12.1 +ml_video_edit_video_segment_gauss_adaptis_part2.pb;2 16 ml_video_edit_img_segment_adaptise.pb;2 40 ml_video_edit_oneclick_adaptis.pb;3 6 #decoder_step_201217.pb is the same model as ml_tts_decoder.pb. diff --git a/mindspore/lite/test/config/models_tflite_fp16.cfg b/mindspore/lite/test/config/models_tflite_fp16.cfg index 99674c0b299..0cb8a1dbafe 100644 --- a/mindspore/lite/test/config/models_tflite_fp16.cfg +++ b/mindspore/lite/test/config/models_tflite_fp16.cfg @@ -2,7 +2,7 @@ # content after ";" can be omitted. # [second column]:accuracy limit for float16 in arm64 device hiai_model_0909_kd_rot_ps_softmax.tflite 10 -hiai_chinese_english_recognize_model_float32.tflite 13 +hiai_chinese_english_recognize_model_float32.tflite 13.5 hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite 10 hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite 10 hiai_cn_recognize_modify_padv2.tflite 14 @@ -64,7 +64,7 @@ inception_resnet_v2.tflite 10 ml_ocr_latin.tflite 15 hiai_PoseEstimation_Pcm.tflite 15 hiai_ssd_mobilenetv2_object.tflite 60 -hiai_cv_focusShootOCRModel_02.tflite 13 +hiai_cv_focusShootOCRModel_02.tflite 13.5 hiai_cv_poseEstimation.tflite 190 inception_v4.tflite 10 mtk_model_normalize_object_scene_ps_20200519_f16.tflite 10 @@ -129,8 +129,8 @@ mtk_pose.tflite 2 mtk_model_emotions_0727_nosoftmax.tflite 2 mtk_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite 22 mtk_276landmark_0913.tflite 16 -mtk_face_recognition.tflite 8 -mtk_convert_model.tflite 5.3 +mtk_face_recognition.tflite 11 +mtk_convert_model.tflite 7.5 smartreply.tflite 0.1 mindspore_text_classification_tflite.tflite 9.2 # small output causes big bias #ml_location.tflite 0.1 @@ -176,7 +176,7 @@ Q_convert.tflite 12 # the input of Q_crnn_ori_75w_slim model is between 0-255, but its outputs has small values (e-6). Q_crnn_ori_75w_slim_norm_pb2tflite.tflite 29 # the output of Q_crnn_ori_v2 model has small values (e-4). -Q_crnn_ori_v2_405001_notrans_nopre_pb2tflite.tflite 36 +Q_crnn_ori_v2_405001_notrans_nopre_pb2tflite.tflite 42 # the inputs of two Q_crnn_screen_slim400w models are between 0-255, but their outputs have small values (e-7). Q_crnn_screen_slim400w_more_20w_pb2tflite.tflite 71 Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid_tflite.tflite 1.5 @@ -202,14 +202,14 @@ Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite 2 # input data: -1~1 Q888_face_emo_dress_mv3_orderd.tflite 2.5 Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite 1 -Q_iMaxSR_RGB_385_p_pb2tflite.tflite 5 +Q_iMaxSR_RGB_385_p_pb2tflite.tflite 5.5 bloom_new_detect.tflite 3.5 bloom_model_age_gender.tflite 0.5 bloom_isface.tflite 0.5 # The output values of conv layers range from -e±5 to e±5, which almost reaches the representation limit of fp16. In # this range, the fp16 data will has big bias. And the accumulation of this bias lowers the final precision. hiai_object_detect_814.tflite 14 -ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2 12.1 +ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2 16 ml_video_edit_img_segment_adaptise_pb2tflite.tflite;2 0.5 hdc_tb_cn_neg.tflite;3 295 # The input of hiai_cv_labelDetectorModel_v3.tflite is between 0-255.