From 23a6ef416ea1b71681cd1c926d6097185a705416 Mon Sep 17 00:00:00 2001 From: lzk Date: Tue, 28 Sep 2021 02:37:27 -0700 Subject: [PATCH] winograd op --- .../cpu/nnacl/fp32/conv_winograd_fp32.c | 4 +- .../cpu/nnacl/fp32/winograd_avx.c | 250 ++++++++++++++++++ .../cpu/nnacl/fp32/winograd_avx.h | 6 + .../cpu/nnacl/fp32/winograd_transform.c | 35 ++- .../cpu/nnacl/fp32/winograd_utils.c | 9 +- .../arm/fp32/convolution_winograd_fp32.cc | 4 +- .../arm/fp32/convolution_winograd_fp32.h | 1 + 7 files changed, 290 insertions(+), 19 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.c index 2a8530d046e..65d20b9742d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.c @@ -35,8 +35,10 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const int output_tile_count = UP_DIV(output_count, tile_num); #ifdef ENABLE_AVX const int col_tile = C16NUM; + const int tmp_data_tile = C8NUM; #else const int col_tile = C8NUM; + const int tmp_data_tile = C4NUM; #endif int oc_tile = UP_DIV(conv_param->output_channel_, col_tile); int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); @@ -51,7 +53,7 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *col_buffer = buffer_list[3]; int trans_input_offset = tile_num * input_unit_square * in_channel; int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM; - int tmp_data_offset = input_unit_square * C4NUM; + int tmp_data_offset = input_unit_square * tmp_data_tile; int col_buffer_offset = tile_num * in_channel; // step 1 : filter transform (pre-processed offline) // step 2 : input transform (online) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_avx.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_avx.c index 356959a8613..fb106969513 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_avx.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_avx.c @@ -17,6 +17,256 @@ #include "nnacl/fp32/winograd_avx.h" #include "nnacl/intrinsics/ms_simd_instructions.h" +void InputTransform4x4AvxUnit(const float *src_data, float *dst_data, const int src_step, const int dst_step, + const int real_c) { + if (real_c == C8NUM) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[16]; + MS_FLOAT32X8 m[16]; + LoadAvx16Data; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_SUB256_F32(src[offset], src[2 + offset]); + t[4 + l] = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + t[8 + l] = MS_SUB256_F32(src[2 + offset], src[1 + offset]); + t[12 + l] = MS_SUB256_F32(src[3 + offset], src[1 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = MS_SUB256_F32(t[offset], t[2 + offset]); + m[4 + l] = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + m[8 + l] = MS_SUB256_F32(t[2 + offset], t[1 + offset]); + m[12 + l] = MS_SUB256_F32(t[3 + offset], t[1 + offset]); + } + for (int i = 0; i < 16; i++) { + MS_ST256_F32(dst_data + i * dst_step, m[i]); + } + } else { + float src[16]; + float t[16]; + float m[16]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] - src[2 + offset]; + t[4 + l] = src[1 + offset] + src[2 + offset]; + t[8 + l] = src[2 + offset] - src[1 + offset]; + t[12 + l] = src[3 + offset] - src[1 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = t[offset] - t[2 + offset]; + m[4 + l] = t[1 + offset] + t[2 + offset]; + m[8 + l] = t[2 + offset] - t[1 + offset]; + m[12 + l] = t[3 + offset] - t[1 + offset]; + } + for (int k = 0; k < 16; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } + } +} + +void InputTransform6x6AvxUnit(const float *src_data, float *dst_data, const int src_step, const int dst_step, + const int real_c) { + if (real_c == C8NUM) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[36]; + MS_FLOAT32X8 m[36]; + LoadAvx36Data; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_SUB256_F32(src[3 + offset], src[1 + offset]); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(src[4 + offset], src[2 + offset]); + t[l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[offset], 4), MS_MUL256_N_F32(src[2 + offset], 5)), + src[4 + offset]); + t[6 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_ADD256_F32(src[1 + offset], src[2 + offset]), -4), + MS_ADD256_F32(src[3 + offset], src[4 + offset])); + t[12 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), 4), + MS_SUB256_F32(src[4 + offset], src[3 + offset])); + t[18 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 2), tmp2); + t[24 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, -2), tmp2); + t[30 + l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[1 + offset], 4), MS_MUL256_N_F32(src[3 + offset], 5)), + src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_SUB256_F32(t[3 + offset], t[1 + offset]); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(t[4 + offset], t[2 + offset]); + m[l] = + MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[offset], 4), MS_MUL256_N_F32(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_ADD256_F32(t[1 + offset], t[2 + offset]), -4), + MS_ADD256_F32(t[3 + offset], t[4 + offset])); + m[12 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), 4), + MS_SUB256_F32(t[4 + offset], t[3 + offset])); + m[18 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 2), tmp2); + m[24 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, -2), tmp2); + m[30 + l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[1 + offset], 4), MS_MUL256_N_F32(t[3 + offset], 5)), + t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + MS_ST256_F32(dst_data + i * dst_step, m[i]); + } + } else { + float src[36]; + float t[36]; + float m[36]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = src[3 + offset] - src[1 + offset]; + float tmp2 = src[4 + offset] - src[2 + offset]; + t[l] = 4 * src[offset] - 5 * src[2 + offset] + src[4 + offset]; + t[6 + l] = -4 * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]); + t[12 + l] = 4 * (src[1 + offset] - src[2 + offset]) + (src[4 + offset] - src[3 + offset]); + t[18 + l] = 2 * tmp1 + tmp2; + t[24 + l] = -2 * tmp1 + tmp2; + t[30 + l] = 4 * src[1 + offset] - 5 * src[3 + offset] + src[5 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = t[3 + offset] - t[1 + offset]; + float tmp2 = t[4 + offset] - t[2 + offset]; + m[l] = 4 * t[offset] - 5 * t[2 + offset] + t[4 + offset]; + m[6 + l] = -4 * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]); + m[12 + l] = 4 * (t[1 + offset] - t[2 + offset]) + (t[4 + offset] - t[3 + offset]); + m[18 + l] = 2 * tmp1 + tmp2; + m[24 + l] = -2 * tmp1 + tmp2; + m[30 + l] = 4 * t[1 + offset] - 5 * t[3 + offset] + t[5 + offset]; + } + for (int k = 0; k < 36; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } + } +} + +void InputTransform8x8AvxUnit_block8(const float *src_data, float *dst_data, const int src_step, const int dst_step) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[64]; + MS_FLOAT32X8 m[64]; + LoadAvx64Data; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = MS_SUB256_F32( + MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[offset], 0.5625), MS_MUL256_N_F32(src[2 + offset], 3.0625)), + MS_MUL256_N_F32(src[4 + offset], 3.5)), + src[6 + offset]); + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 1.125), MS_MUL256_N_F32(src[5 + offset], 0.5)); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 2.25), MS_MUL256_N_F32(src[4 + offset], 3.25)); + t[8 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 0.5625), MS_MUL256_N_F32(src[4 + offset], 2.5)); + t[24 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 0.375), MS_MUL256_N_F32(src[5 + offset], 1.5)); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 0.25), MS_MUL256_N_F32(src[4 + offset], 1.25)); + t[40 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = MS_ADD256_F32( + MS_SUB256_F32(MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], -0.5625), MS_MUL256_N_F32(src[3 + offset], 3.0625)), + MS_MUL256_N_F32(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = MS_SUB256_F32( + MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[offset], 0.5625), MS_MUL256_N_F32(t[2 + offset], 3.0625)), + MS_MUL256_N_F32(t[4 + offset], 3.5)), + t[6 + offset]); + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 1.125), MS_MUL256_N_F32(t[5 + offset], 0.5)); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 2.25), MS_MUL256_N_F32(t[4 + offset], 3.25)); + m[8 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 0.5625), MS_MUL256_N_F32(t[4 + offset], 2.5)); + m[24 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 0.375), MS_MUL256_N_F32(t[5 + offset], 1.5)); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 0.25), MS_MUL256_N_F32(t[4 + offset], 1.25)); + m[40 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = MS_ADD256_F32( + MS_SUB256_F32(MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], -0.5625), MS_MUL256_N_F32(t[3 + offset], 3.0625)), + MS_MUL256_N_F32(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + MS_ST256_F32(dst_data + i * dst_step, m[i]); + } +} + +void InputTransform8x8AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + if (real_c == C8NUM) { + InputTransform8x8AvxUnit_block8(src_data, dst_data, src_step, dst_step); + } else { + float src[64]; + float t[64]; + float m[64]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = 0.5625f * src[offset] - 3.0625f * src[2 + offset] + 3.5f * src[4 + offset] - src[6 + offset]; + float tmp1 = 1.125f * src[1 + offset] + 0.5f * src[5 + offset]; + float tmp2 = 2.25f * src[2 + offset] - 3.25f * src[4 + offset]; + t[8 + l] = tmp1 + tmp2 - 1.625f * src[3 + offset] + src[6 + offset]; + t[16 + l] = tmp2 - tmp1 + 1.625f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.5625f * src[1 + offset] + src[5 + offset]; + tmp2 = 0.5625f * src[2 + offset] - 2.5f * src[4 + offset]; + t[24 + l] = tmp1 + tmp2 - 2.5f * src[3 + offset] + src[6 + offset]; + t[32 + l] = tmp2 - tmp1 + 2.5f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.375f * src[1 + offset] + 1.5f * src[5 + offset]; + tmp2 = 0.25f * src[2 + offset] - 1.25f * src[4 + offset]; + t[40 + l] = tmp1 + tmp2 - 1.875f * src[3 + offset] + src[6 + offset]; + t[48 + l] = tmp2 - tmp1 + 1.875f * src[3 + offset] + src[6 + offset]; + t[56 + l] = -0.5625f * src[1 + offset] + 3.0625f * src[3 + offset] - 3.5f * src[5 + offset] + src[7 + offset]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = 0.5625f * t[offset] - 3.0625f * t[2 + offset] + 3.5f * t[4 + offset] - t[6 + offset]; + float tmp1 = 1.125f * t[1 + offset] + 0.5f * t[5 + offset]; + float tmp2 = 2.25f * t[2 + offset] - 3.25f * t[4 + offset]; + m[8 + l] = tmp1 + tmp2 - 1.625f * t[3 + offset] + t[6 + offset]; + m[16 + l] = tmp2 - tmp1 + 1.625f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.5625f * t[1 + offset] + t[5 + offset]; + tmp2 = 0.5625f * t[2 + offset] - 2.5f * t[4 + offset]; + m[24 + l] = tmp1 + tmp2 - 2.5f * t[3 + offset] + t[6 + offset]; + m[32 + l] = tmp2 - tmp1 + 2.5f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.375f * t[1 + offset] + 1.5f * t[5 + offset]; + tmp2 = 0.25f * t[2 + offset] - 1.25f * t[4 + offset]; + m[40 + l] = tmp1 + tmp2 - 1.875f * t[3 + offset] + t[6 + offset]; + m[48 + l] = tmp2 - tmp1 + 1.875f * t[3 + offset] + t[6 + offset]; + m[56 + l] = -0.5625f * t[1 + offset] + 3.0625f * t[3 + offset] - 3.5f * t[5 + offset] + t[7 + offset]; + } + for (int k = 0; k < 64; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } + } +} + void OutputTransform4x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { MS_FLOAT32X8 src[16]; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_avx.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_avx.h index 2bc73d940d6..d024907fd58 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_avx.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_avx.h @@ -210,6 +210,12 @@ typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const fl MS_ST256_F32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ MS_ST256_F32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); +void InputTransform4x4AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + void OutputTransform4x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); void OutputTransform4x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_transform.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_transform.c index 16f20b14572..73c1b7acee6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_transform.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_transform.c @@ -24,7 +24,12 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float * int input_unit = conv_param->input_unit_; int output_unit = conv_param->output_unit_; int in_channel = conv_param->input_channel_; - int ic4 = UP_DIV(in_channel, C4NUM); +#ifdef ENABLE_AVX + int tile = C8NUM; +#else + int tile = C4NUM; +#endif + int ic4 = UP_DIV(in_channel, tile); int pad_h = conv_param->pad_u_; int pad_w = conv_param->pad_l_; int input_h = conv_param->input_h_; @@ -45,25 +50,27 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float * int dst_plane_offset = c * in_channel; for (int ic = 0; ic < ic4; ic++) { // clear tmp buffer - memset(tmp_data, 0, input_unit * input_unit * C4NUM * (int)(sizeof(float))); + memset(tmp_data, 0, input_unit * input_unit * tile * (int)(sizeof(float))); - int real_c = in_channel - ic * C4NUM; - real_c = real_c > C4NUM ? C4NUM : real_c; - int src_ic4_offset = src_plane_offset + ic * C4NUM; + int real_c = in_channel - ic * tile; + real_c = real_c > tile ? tile : real_c; + int src_ic4_offset = src_plane_offset + ic * tile; // get real input block with padding - if (real_c == C4NUM) { + if (real_c == tile) { for (int interval = interval_y_s; interval < interval_y_e; interval++) { int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * in_channel; - int dst_y_offset = interval * input_unit * C4NUM + interval_x_s * C4NUM; + int dst_y_offset = interval * input_unit * tile + interval_x_s * tile; for (int j = 0; j < (interval_x_e - interval_x_s); j++) { int src_x_offset = src_y_offset + j * in_channel; - int dst_x_offset = dst_y_offset + j * C4NUM; + int dst_x_offset = dst_y_offset + j * tile; float *src_addr = (float *)(input_data) + src_x_offset; float *dst_addr = tmp_data + dst_x_offset; -#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +#ifdef ENABLE_AVX + MS_ST256_F32(dst_addr, MS_LD256_F32(src_addr)); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) MS_STQ_F32(dst_addr, MS_LDQ_F32(src_addr)); #else - for (int k = 0; k < C4NUM; k++) { + for (int k = 0; k < tile; k++) { dst_addr[k] = src_addr[k]; } #endif @@ -72,10 +79,10 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float * } else { for (int interval = interval_y_s; interval < interval_y_e; interval++) { int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * in_channel; - int dst_y_offset = interval * input_unit * C4NUM + interval_x_s * C4NUM; + int dst_y_offset = interval * input_unit * tile + interval_x_s * tile; for (int j = 0; j < (interval_x_e - interval_x_s); j++) { int src_x_offset = src_y_offset + j * in_channel; - int dst_x_offset = dst_y_offset + j * C4NUM; + int dst_x_offset = dst_y_offset + j * tile; float *src_addr = (float *)(input_data) + src_x_offset; float *dst_addr = tmp_data + dst_x_offset; for (int k = 0; k < real_c; k++) { @@ -86,10 +93,10 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float * } // input transform const int tile_num = C12NUM; - int dst_ic4_offset = dst_plane_offset + ic * C4NUM; + int dst_ic4_offset = dst_plane_offset + ic * tile; int dst_step = tile_num * in_channel; float *trans_input_ptr = trans_input + dst_ic4_offset; - func(tmp_data, trans_input_ptr, C4NUM, dst_step, real_c); + func(tmp_data, trans_input_ptr, tile, dst_step, real_c); } out_tile_index++; } // cal_tile_num loop diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_utils.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_utils.c index 055afca2a21..a167e117a11 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_utils.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/winograd_utils.c @@ -20,10 +20,10 @@ #include "nnacl/base/conv_common_base.h" #include "nnacl/errorcode.h" -static InputTransFunc InputTransFuncList[] = { - NULL, NULL, NULL, NULL, InputTransform4x4Unit, NULL, InputTransform6x6Unit, NULL, InputTransform8x8Unit}; - #ifdef ENABLE_AVX +static InputTransFunc InputTransFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4AvxUnit, NULL, InputTransform6x6AvxUnit, NULL, InputTransform8x8AvxUnit}; + static OutputTransFunc OutputTransFuncList[] = { OutputTransform4x2AvxUnit, OutputTransform4x3AvxUnit, OutputTransform4x2ReluAvxUnit, OutputTransform4x3ReluAvxUnit, OutputTransform4x2Relu6AvxUnit, OutputTransform4x3Relu6AvxUnit, @@ -38,6 +38,9 @@ static OutputTransFunc OutputTransFuncList[] = { OutputTransform8x2Relu6AvxUnit, OutputTransform8x3Relu6AvxUnit, OutputTransform8x4Relu6AvxUnit, OutputTransform8x5Relu6AvxUnit, OutputTransform8x6Relu6AvxUnit, OutputTransform8x7Relu6AvxUnit}; #else +static InputTransFunc InputTransFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4Unit, NULL, InputTransform6x6Unit, NULL, InputTransform8x8Unit}; + static OutputTransFunc OutputTransFuncList[] = { OutputTransform4x2Unit, OutputTransform4x3Unit, OutputTransform4x2ReluUnit, OutputTransform4x3ReluUnit, OutputTransform4x2Relu6Unit, OutputTransform4x3Relu6Unit, OutputTransform6x2Unit, OutputTransform6x3Unit, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc index 692923fdc68..82e17819660 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc @@ -56,7 +56,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { } tmp_data_ = reinterpret_cast( - ctx_->allocator->Malloc(thread_count_ * C4NUM * input_unit_ * input_unit_ * sizeof(float))); + ctx_->allocator->Malloc(thread_count_ * tmp_data_tile_ * input_unit_ * input_unit_ * sizeof(float))); if (tmp_data_ == nullptr) { MS_LOG(ERROR) << "malloc tmp_data_ failed."; return RET_MEMORY_FAILED; @@ -96,8 +96,10 @@ int ConvolutionWinogradCPUKernel::Init() { tile_num_ = C12NUM; #ifdef ENABLE_AVX oc_block_ = C16NUM; + tmp_data_tile_ = C8NUM; #else oc_block_ = C8NUM; + tmp_data_tile_ = C4NUM; #endif kernel_unit_ = conv_param_->kernel_h_; input_unit_ = output_unit_ + kernel_unit_ - 1; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h index 15e10ecd0ca..d90283fd5b0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h @@ -68,6 +68,7 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { int output_unit_{0}; int oc_block_{0}; int tile_num_{0}; + int tmp_data_tile_{0}; float *tmp_data_ = nullptr; float *trans_input_ = nullptr; float *gemm_out_ = nullptr;