!24328 [MS][LITE][CPU] winograd optimize

Merge pull request !24328 from liuzhongkai/winograd_op
This commit is contained in:
i-robot 2021-10-11 01:09:33 +00:00 committed by Gitee
commit 2c738757c3
7 changed files with 290 additions and 19 deletions

View File

@ -35,8 +35,10 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
int output_tile_count = UP_DIV(output_count, tile_num);
#ifdef ENABLE_AVX
const int col_tile = C16NUM;
const int tmp_data_tile = C8NUM;
#else
const int col_tile = C8NUM;
const int tmp_data_tile = C4NUM;
#endif
int oc_tile = UP_DIV(conv_param->output_channel_, col_tile);
int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
@ -51,7 +53,7 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
float *col_buffer = buffer_list[3];
int trans_input_offset = tile_num * input_unit_square * in_channel;
int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
int tmp_data_offset = input_unit_square * C4NUM;
int tmp_data_offset = input_unit_square * tmp_data_tile;
int col_buffer_offset = tile_num * in_channel;
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)

View File

@ -17,6 +17,256 @@
#include "nnacl/fp32/winograd_avx.h"
#include "nnacl/intrinsics/ms_simd_instructions.h"
// Winograd input transform for one 4x4 tile (16 positions).
//   src_data : position p of the tile is read at src_data + p * src_step, channels are contiguous there.
//   dst_data : transformed position p is written at dst_data + p * dst_step.
//   real_c   : number of valid channels. C8NUM selects the vector path; any other value is handled one
//              channel at a time by the scalar path.
// One pass maps the 4 entries of a line to (x0 - x2, x1 + x2, x2 - x1, x3 - x1) and stores the results
// transposed (out[k * 4 + line]); running the pass twice therefore covers both tile dimensions.
void InputTransform4x4AvxUnit(const float *src_data, float *dst_data, const int src_step, const int dst_step,
                              const int real_c) {
  if (real_c != C8NUM) {
    // Scalar path: gather, transform and scatter each valid channel separately.
    float tile[16];
    float stage1[16];
    float stage2[16];
    for (int ch = 0; ch < real_c; ++ch) {
      for (int pos = 0; pos < 16; ++pos) {
        tile[pos] = src_data[ch + pos * src_step];
      }
      for (int line = 0; line < 4; ++line) {
        const int base = line * 4;
        stage1[line] = tile[base] - tile[base + 2];
        stage1[line + 4] = tile[base + 1] + tile[base + 2];
        stage1[line + 8] = tile[base + 2] - tile[base + 1];
        stage1[line + 12] = tile[base + 3] - tile[base + 1];
      }
      for (int line = 0; line < 4; ++line) {
        const int base = line * 4;
        stage2[line] = stage1[base] - stage1[base + 2];
        stage2[line + 4] = stage1[base + 1] + stage1[base + 2];
        stage2[line + 8] = stage1[base + 2] - stage1[base + 1];
        stage2[line + 12] = stage1[base + 3] - stage1[base + 1];
      }
      for (int pos = 0; pos < 16; ++pos) {
        dst_data[ch + pos * dst_step] = stage2[pos];
      }
    }
    return;
  }

  // Vector path: one MS_FLOAT32X8 per tile position.
  // NOTE: the array must be called `src`; LoadAvx16Data takes no arguments and is defined elsewhere,
  // presumably filling src[0..15] from src_data using src_step -- confirm against the macro.
  MS_FLOAT32X8 src[16];
  MS_FLOAT32X8 stage1[16];
  MS_FLOAT32X8 stage2[16];
  LoadAvx16Data;
  for (int line = 0; line < 4; ++line) {
    const int base = line * 4;
    stage1[line] = MS_SUB256_F32(src[base], src[base + 2]);
    stage1[line + 4] = MS_ADD256_F32(src[base + 1], src[base + 2]);
    stage1[line + 8] = MS_SUB256_F32(src[base + 2], src[base + 1]);
    stage1[line + 12] = MS_SUB256_F32(src[base + 3], src[base + 1]);
  }
  for (int line = 0; line < 4; ++line) {
    const int base = line * 4;
    stage2[line] = MS_SUB256_F32(stage1[base], stage1[base + 2]);
    stage2[line + 4] = MS_ADD256_F32(stage1[base + 1], stage1[base + 2]);
    stage2[line + 8] = MS_SUB256_F32(stage1[base + 2], stage1[base + 1]);
    stage2[line + 12] = MS_SUB256_F32(stage1[base + 3], stage1[base + 1]);
  }
  for (int pos = 0; pos < 16; ++pos) {
    MS_ST256_F32(dst_data + pos * dst_step, stage2[pos]);
  }
}
// Winograd input transform for one 6x6 tile (36 positions).
//   src_data : position p of the tile is read at src_data + p * src_step, channels are contiguous there.
//   dst_data : transformed position p is written at dst_data + p * dst_step.
//   real_c   : number of valid channels. C8NUM selects the vector path; any other value is handled one
//              channel at a time by the scalar path.
// One pass maps the 6 entries x0..x5 of a line to
//   4*x0 - 5*x2 + x4,
//   -4*(x1 + x2) + (x3 + x4),
//   4*(x1 - x2) + (x4 - x3),
//   2*(x3 - x1) + (x4 - x2),
//   -2*(x3 - x1) + (x4 - x2),
//   4*x1 - 5*x3 + x5
// and stores the results transposed (out[k * 6 + line]); the pass is applied twice.
void InputTransform6x6AvxUnit(const float *src_data, float *dst_data, const int src_step, const int dst_step,
                              const int real_c) {
  if (real_c != C8NUM) {
    // Scalar path: gather, transform and scatter each valid channel separately.
    float tile[36];
    float stage1[36];
    float stage2[36];
    for (int ch = 0; ch < real_c; ++ch) {
      for (int pos = 0; pos < 36; ++pos) {
        tile[pos] = src_data[ch + pos * src_step];
      }
      for (int line = 0; line < 6; ++line) {
        const int base = line * 6;
        float odd_diff = tile[base + 3] - tile[base + 1];
        float even_diff = tile[base + 4] - tile[base + 2];
        stage1[line] = 4 * tile[base] - 5 * tile[base + 2] + tile[base + 4];
        stage1[line + 6] = -4 * (tile[base + 1] + tile[base + 2]) + (tile[base + 3] + tile[base + 4]);
        stage1[line + 12] = 4 * (tile[base + 1] - tile[base + 2]) + (tile[base + 4] - tile[base + 3]);
        stage1[line + 18] = 2 * odd_diff + even_diff;
        stage1[line + 24] = -2 * odd_diff + even_diff;
        stage1[line + 30] = 4 * tile[base + 1] - 5 * tile[base + 3] + tile[base + 5];
      }
      for (int line = 0; line < 6; ++line) {
        const int base = line * 6;
        float odd_diff = stage1[base + 3] - stage1[base + 1];
        float even_diff = stage1[base + 4] - stage1[base + 2];
        stage2[line] = 4 * stage1[base] - 5 * stage1[base + 2] + stage1[base + 4];
        stage2[line + 6] = -4 * (stage1[base + 1] + stage1[base + 2]) + (stage1[base + 3] + stage1[base + 4]);
        stage2[line + 12] = 4 * (stage1[base + 1] - stage1[base + 2]) + (stage1[base + 4] - stage1[base + 3]);
        stage2[line + 18] = 2 * odd_diff + even_diff;
        stage2[line + 24] = -2 * odd_diff + even_diff;
        stage2[line + 30] = 4 * stage1[base + 1] - 5 * stage1[base + 3] + stage1[base + 5];
      }
      for (int pos = 0; pos < 36; ++pos) {
        dst_data[ch + pos * dst_step] = stage2[pos];
      }
    }
    return;
  }

  // Vector path: one MS_FLOAT32X8 per tile position.
  // NOTE: the array must be called `src`; LoadAvx36Data takes no arguments and is defined elsewhere,
  // presumably filling src[0..35] from src_data using src_step -- confirm against the macro.
  MS_FLOAT32X8 src[36];
  MS_FLOAT32X8 stage1[36];
  MS_FLOAT32X8 stage2[36];
  LoadAvx36Data;
  for (int line = 0; line < 6; ++line) {
    const int base = line * 6;
    MS_FLOAT32X8 odd_diff = MS_SUB256_F32(src[base + 3], src[base + 1]);
    MS_FLOAT32X8 even_diff = MS_SUB256_F32(src[base + 4], src[base + 2]);
    stage1[line] = MS_ADD256_F32(
      MS_SUB256_F32(MS_MUL256_N_F32(src[base], 4), MS_MUL256_N_F32(src[base + 2], 5)), src[base + 4]);
    stage1[line + 6] = MS_ADD256_F32(MS_MUL256_N_F32(MS_ADD256_F32(src[base + 1], src[base + 2]), -4),
                                     MS_ADD256_F32(src[base + 3], src[base + 4]));
    stage1[line + 12] = MS_ADD256_F32(MS_MUL256_N_F32(MS_SUB256_F32(src[base + 1], src[base + 2]), 4),
                                      MS_SUB256_F32(src[base + 4], src[base + 3]));
    stage1[line + 18] = MS_ADD256_F32(MS_MUL256_N_F32(odd_diff, 2), even_diff);
    stage1[line + 24] = MS_ADD256_F32(MS_MUL256_N_F32(odd_diff, -2), even_diff);
    stage1[line + 30] = MS_ADD256_F32(
      MS_SUB256_F32(MS_MUL256_N_F32(src[base + 1], 4), MS_MUL256_N_F32(src[base + 3], 5)), src[base + 5]);
  }
  for (int line = 0; line < 6; ++line) {
    const int base = line * 6;
    MS_FLOAT32X8 odd_diff = MS_SUB256_F32(stage1[base + 3], stage1[base + 1]);
    MS_FLOAT32X8 even_diff = MS_SUB256_F32(stage1[base + 4], stage1[base + 2]);
    stage2[line] = MS_ADD256_F32(
      MS_SUB256_F32(MS_MUL256_N_F32(stage1[base], 4), MS_MUL256_N_F32(stage1[base + 2], 5)), stage1[base + 4]);
    stage2[line + 6] = MS_ADD256_F32(MS_MUL256_N_F32(MS_ADD256_F32(stage1[base + 1], stage1[base + 2]), -4),
                                     MS_ADD256_F32(stage1[base + 3], stage1[base + 4]));
    stage2[line + 12] = MS_ADD256_F32(MS_MUL256_N_F32(MS_SUB256_F32(stage1[base + 1], stage1[base + 2]), 4),
                                      MS_SUB256_F32(stage1[base + 4], stage1[base + 3]));
    stage2[line + 18] = MS_ADD256_F32(MS_MUL256_N_F32(odd_diff, 2), even_diff);
    stage2[line + 24] = MS_ADD256_F32(MS_MUL256_N_F32(odd_diff, -2), even_diff);
    stage2[line + 30] = MS_ADD256_F32(
      MS_SUB256_F32(MS_MUL256_N_F32(stage1[base + 1], 4), MS_MUL256_N_F32(stage1[base + 3], 5)), stage1[base + 5]);
  }
  for (int pos = 0; pos < 36; ++pos) {
    MS_ST256_F32(dst_data + pos * dst_step, stage2[pos]);
  }
}
// Vector-only Winograd input transform for one 8x8 tile (64 positions); InputTransform8x8AvxUnit calls it
// when real_c == C8NUM.
//   src_data / src_step : consumed by LoadAvx64Data (macro defined elsewhere; presumably loads src[p] from
//                         src_data + p * src_step -- confirm against the macro).
//   dst_data / dst_step : transformed position p is stored at dst_data + p * dst_step.
// The same 8-term linear map is applied twice. Each pass reads the 8 consecutive entries in[l * 8 + 0..7] and
// writes output k to out[k * 8 + l], i.e. transposed, so the two passes together give B * src * B^T for the
// coefficient matrix B spelled out by the formulas below.
void InputTransform8x8AvxUnit_block8(const float *src_data, float *dst_data, const int src_step, const int dst_step) {
MS_FLOAT32X8 src[64];
MS_FLOAT32X8 t[64];
MS_FLOAT32X8 m[64];
// Fills src[]; the macro takes no arguments, so it depends on the local names in scope here.
LoadAvx64Data;
// Pass 1: src -> t.
for (int l = 0; l < 8; ++l) {
int offset = l * 8;
// Output 0: 0.5625*x0 - 3.0625*x2 + 3.5*x4 - x6.
t[l] = MS_SUB256_F32(
MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[offset], 0.5625), MS_MUL256_N_F32(src[2 + offset], 3.0625)),
MS_MUL256_N_F32(src[4 + offset], 3.5)),
src[6 + offset]);
// Outputs 1 and 2 share an odd-index part (tmp1) and an even-index part (tmp2), differing only in sign.
MS_FLOAT32X8 tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 1.125), MS_MUL256_N_F32(src[5 + offset], 0.5));
MS_FLOAT32X8 tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 2.25), MS_MUL256_N_F32(src[4 + offset], 3.25));
t[8 + l] =
MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 1.625)), src[6 + offset]);
t[16 + l] =
MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 1.625)), src[6 + offset]);
// Outputs 3 and 4: same pattern with a second coefficient set.
tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 0.5625), src[5 + offset]);
tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 0.5625), MS_MUL256_N_F32(src[4 + offset], 2.5));
t[24 + l] =
MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 2.5)), src[6 + offset]);
t[32 + l] =
MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 2.5)), src[6 + offset]);
// Outputs 5 and 6: same pattern with a third coefficient set.
tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 0.375), MS_MUL256_N_F32(src[5 + offset], 1.5));
tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 0.25), MS_MUL256_N_F32(src[4 + offset], 1.25));
t[40 + l] =
MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 1.875)), src[6 + offset]);
t[48 + l] =
MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 1.875)), src[6 + offset]);
// Output 7: -0.5625*x1 + 3.0625*x3 - 3.5*x5 + x7.
t[56 + l] = MS_ADD256_F32(
MS_SUB256_F32(MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], -0.5625), MS_MUL256_N_F32(src[3 + offset], 3.0625)),
MS_MUL256_N_F32(src[5 + offset], 3.5)),
src[7 + offset]);
}
// Pass 2: t -> m, identical formulas applied to the transposed intermediate.
for (int l = 0; l < 8; ++l) {
int offset = l * 8;
m[l] = MS_SUB256_F32(
MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[offset], 0.5625), MS_MUL256_N_F32(t[2 + offset], 3.0625)),
MS_MUL256_N_F32(t[4 + offset], 3.5)),
t[6 + offset]);
MS_FLOAT32X8 tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 1.125), MS_MUL256_N_F32(t[5 + offset], 0.5));
MS_FLOAT32X8 tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 2.25), MS_MUL256_N_F32(t[4 + offset], 3.25));
m[8 + l] =
MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 1.625)), t[6 + offset]);
m[16 + l] =
MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 1.625)), t[6 + offset]);
tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 0.5625), t[5 + offset]);
tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 0.5625), MS_MUL256_N_F32(t[4 + offset], 2.5));
m[24 + l] =
MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 2.5)), t[6 + offset]);
m[32 + l] =
MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 2.5)), t[6 + offset]);
tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 0.375), MS_MUL256_N_F32(t[5 + offset], 1.5));
tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 0.25), MS_MUL256_N_F32(t[4 + offset], 1.25));
m[40 + l] =
MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 1.875)), t[6 + offset]);
m[48 + l] =
MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 1.875)), t[6 + offset]);
m[56 + l] = MS_ADD256_F32(
MS_SUB256_F32(MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], -0.5625), MS_MUL256_N_F32(t[3 + offset], 3.0625)),
MS_MUL256_N_F32(t[5 + offset], 3.5)),
t[7 + offset]);
}
// Write the 64 transformed vectors, dst_step floats apart.
for (int i = 0; i < 64; i++) {
MS_ST256_F32(dst_data + i * dst_step, m[i]);
}
}
// Winograd input transform for one 8x8 tile (64 positions).
//   src_data : position p of the tile is read at src_data + p * src_step, channels are contiguous there.
//   dst_data : transformed position p is written at dst_data + p * dst_step.
//   real_c   : number of valid channels. C8NUM is forwarded to InputTransform8x8AvxUnit_block8; any other
//              value is handled one channel at a time by the scalar code below.
// One pass turns the 8 entries of a line into 8 linear combinations and stores them transposed
// (out[k * 8 + line]); the pass is applied twice.
void InputTransform8x8AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
  if (real_c == C8NUM) {
    InputTransform8x8AvxUnit_block8(src_data, dst_data, src_step, dst_step);
    return;
  }

  float tile[64];
  float stage1[64];
  float stage2[64];
  for (int ch = 0; ch < real_c; ++ch) {
    // Gather the 64 tile positions of this channel.
    for (int pos = 0; pos < 64; ++pos) {
      tile[pos] = src_data[ch + pos * src_step];
    }
    // Pass 1: tile -> stage1. Outputs 1..6 come in pairs sharing an odd-index and an even-index part.
    for (int line = 0; line < 8; ++line) {
      const int base = line * 8;
      stage1[line] = 0.5625f * tile[base] - 3.0625f * tile[base + 2] + 3.5f * tile[base + 4] - tile[base + 6];
      float odd_a = 1.125f * tile[base + 1] + 0.5f * tile[base + 5];
      float even_a = 2.25f * tile[base + 2] - 3.25f * tile[base + 4];
      stage1[line + 8] = odd_a + even_a - 1.625f * tile[base + 3] + tile[base + 6];
      stage1[line + 16] = even_a - odd_a + 1.625f * tile[base + 3] + tile[base + 6];
      float odd_b = 0.5625f * tile[base + 1] + tile[base + 5];
      float even_b = 0.5625f * tile[base + 2] - 2.5f * tile[base + 4];
      stage1[line + 24] = odd_b + even_b - 2.5f * tile[base + 3] + tile[base + 6];
      stage1[line + 32] = even_b - odd_b + 2.5f * tile[base + 3] + tile[base + 6];
      float odd_c = 0.375f * tile[base + 1] + 1.5f * tile[base + 5];
      float even_c = 0.25f * tile[base + 2] - 1.25f * tile[base + 4];
      stage1[line + 40] = odd_c + even_c - 1.875f * tile[base + 3] + tile[base + 6];
      stage1[line + 48] = even_c - odd_c + 1.875f * tile[base + 3] + tile[base + 6];
      stage1[line + 56] =
        -0.5625f * tile[base + 1] + 3.0625f * tile[base + 3] - 3.5f * tile[base + 5] + tile[base + 7];
    }
    // Pass 2: stage1 -> stage2, same formulas on the transposed intermediate.
    for (int line = 0; line < 8; ++line) {
      const int base = line * 8;
      stage2[line] =
        0.5625f * stage1[base] - 3.0625f * stage1[base + 2] + 3.5f * stage1[base + 4] - stage1[base + 6];
      float odd_a = 1.125f * stage1[base + 1] + 0.5f * stage1[base + 5];
      float even_a = 2.25f * stage1[base + 2] - 3.25f * stage1[base + 4];
      stage2[line + 8] = odd_a + even_a - 1.625f * stage1[base + 3] + stage1[base + 6];
      stage2[line + 16] = even_a - odd_a + 1.625f * stage1[base + 3] + stage1[base + 6];
      float odd_b = 0.5625f * stage1[base + 1] + stage1[base + 5];
      float even_b = 0.5625f * stage1[base + 2] - 2.5f * stage1[base + 4];
      stage2[line + 24] = odd_b + even_b - 2.5f * stage1[base + 3] + stage1[base + 6];
      stage2[line + 32] = even_b - odd_b + 2.5f * stage1[base + 3] + stage1[base + 6];
      float odd_c = 0.375f * stage1[base + 1] + 1.5f * stage1[base + 5];
      float even_c = 0.25f * stage1[base + 2] - 1.25f * stage1[base + 4];
      stage2[line + 40] = odd_c + even_c - 1.875f * stage1[base + 3] + stage1[base + 6];
      stage2[line + 48] = even_c - odd_c + 1.875f * stage1[base + 3] + stage1[base + 6];
      stage2[line + 56] =
        -0.5625f * stage1[base + 1] + 3.0625f * stage1[base + 3] - 3.5f * stage1[base + 5] + stage1[base + 7];
    }
    // Scatter the transformed values for this channel.
    for (int pos = 0; pos < 64; ++pos) {
      dst_data[ch + pos * dst_step] = stage2[pos];
    }
  }
}
void OutputTransform4x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c) {
MS_FLOAT32X8 src[16];

View File

@ -210,6 +210,12 @@ typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const fl
MS_ST256_F32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \
MS_ST256_F32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]);
void InputTransform4x4AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
void InputTransform6x6AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
void InputTransform8x8AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
void OutputTransform4x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform4x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,

View File

@ -24,7 +24,12 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
int input_unit = conv_param->input_unit_;
int output_unit = conv_param->output_unit_;
int in_channel = conv_param->input_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
#ifdef ENABLE_AVX
int tile = C8NUM;
#else
int tile = C4NUM;
#endif
int ic4 = UP_DIV(in_channel, tile);
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int input_h = conv_param->input_h_;
@ -45,25 +50,27 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
int dst_plane_offset = c * in_channel;
for (int ic = 0; ic < ic4; ic++) {
// clear tmp buffer
memset(tmp_data, 0, input_unit * input_unit * C4NUM * (int)(sizeof(float)));
memset(tmp_data, 0, input_unit * input_unit * tile * (int)(sizeof(float)));
int real_c = in_channel - ic * C4NUM;
real_c = real_c > C4NUM ? C4NUM : real_c;
int src_ic4_offset = src_plane_offset + ic * C4NUM;
int real_c = in_channel - ic * tile;
real_c = real_c > tile ? tile : real_c;
int src_ic4_offset = src_plane_offset + ic * tile;
// get real input block with padding
if (real_c == C4NUM) {
if (real_c == tile) {
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * in_channel;
int dst_y_offset = interval * input_unit * C4NUM + interval_x_s * C4NUM;
int dst_y_offset = interval * input_unit * tile + interval_x_s * tile;
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * in_channel;
int dst_x_offset = dst_y_offset + j * C4NUM;
int dst_x_offset = dst_y_offset + j * tile;
float *src_addr = (float *)(input_data) + src_x_offset;
float *dst_addr = tmp_data + dst_x_offset;
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
#ifdef ENABLE_AVX
MS_ST256_F32(dst_addr, MS_LD256_F32(src_addr));
#elif defined(ENABLE_ARM) || defined(ENABLE_SSE)
MS_STQ_F32(dst_addr, MS_LDQ_F32(src_addr));
#else
for (int k = 0; k < C4NUM; k++) {
for (int k = 0; k < tile; k++) {
dst_addr[k] = src_addr[k];
}
#endif
@ -72,10 +79,10 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
} else {
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * in_channel;
int dst_y_offset = interval * input_unit * C4NUM + interval_x_s * C4NUM;
int dst_y_offset = interval * input_unit * tile + interval_x_s * tile;
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * in_channel;
int dst_x_offset = dst_y_offset + j * C4NUM;
int dst_x_offset = dst_y_offset + j * tile;
float *src_addr = (float *)(input_data) + src_x_offset;
float *dst_addr = tmp_data + dst_x_offset;
for (int k = 0; k < real_c; k++) {
@ -86,10 +93,10 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
}
// input transform
const int tile_num = C12NUM;
int dst_ic4_offset = dst_plane_offset + ic * C4NUM;
int dst_ic4_offset = dst_plane_offset + ic * tile;
int dst_step = tile_num * in_channel;
float *trans_input_ptr = trans_input + dst_ic4_offset;
func(tmp_data, trans_input_ptr, C4NUM, dst_step, real_c);
func(tmp_data, trans_input_ptr, tile, dst_step, real_c);
}
out_tile_index++;
} // cal_tile_num loop

View File

@ -20,10 +20,10 @@
#include "nnacl/base/conv_common_base.h"
#include "nnacl/errorcode.h"
static InputTransFunc InputTransFuncList[] = {
NULL, NULL, NULL, NULL, InputTransform4x4Unit, NULL, InputTransform6x6Unit, NULL, InputTransform8x8Unit};
#ifdef ENABLE_AVX
static InputTransFunc InputTransFuncList[] = {
NULL, NULL, NULL, NULL, InputTransform4x4AvxUnit, NULL, InputTransform6x6AvxUnit, NULL, InputTransform8x8AvxUnit};
static OutputTransFunc OutputTransFuncList[] = {
OutputTransform4x2AvxUnit, OutputTransform4x3AvxUnit, OutputTransform4x2ReluAvxUnit,
OutputTransform4x3ReluAvxUnit, OutputTransform4x2Relu6AvxUnit, OutputTransform4x3Relu6AvxUnit,
@ -38,6 +38,9 @@ static OutputTransFunc OutputTransFuncList[] = {
OutputTransform8x2Relu6AvxUnit, OutputTransform8x3Relu6AvxUnit, OutputTransform8x4Relu6AvxUnit,
OutputTransform8x5Relu6AvxUnit, OutputTransform8x6Relu6AvxUnit, OutputTransform8x7Relu6AvxUnit};
#else
static InputTransFunc InputTransFuncList[] = {
NULL, NULL, NULL, NULL, InputTransform4x4Unit, NULL, InputTransform6x6Unit, NULL, InputTransform8x8Unit};
static OutputTransFunc OutputTransFuncList[] = {
OutputTransform4x2Unit, OutputTransform4x3Unit, OutputTransform4x2ReluUnit, OutputTransform4x3ReluUnit,
OutputTransform4x2Relu6Unit, OutputTransform4x3Relu6Unit, OutputTransform6x2Unit, OutputTransform6x3Unit,

View File

@ -56,7 +56,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
}
tmp_data_ = reinterpret_cast<float *>(
ctx_->allocator->Malloc(thread_count_ * C4NUM * input_unit_ * input_unit_ * sizeof(float)));
ctx_->allocator->Malloc(thread_count_ * tmp_data_tile_ * input_unit_ * input_unit_ * sizeof(float)));
if (tmp_data_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_data_ failed.";
return RET_MEMORY_FAILED;
@ -96,8 +96,10 @@ int ConvolutionWinogradCPUKernel::Prepare() {
tile_num_ = C12NUM;
#ifdef ENABLE_AVX
oc_block_ = C16NUM;
tmp_data_tile_ = C8NUM;
#else
oc_block_ = C8NUM;
tmp_data_tile_ = C4NUM;
#endif
kernel_unit_ = conv_param_->kernel_h_;
input_unit_ = output_unit_ + kernel_unit_ - 1;

View File

@ -68,6 +68,7 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
int output_unit_{0};
int oc_block_{0};
int tile_num_{0};
int tmp_data_tile_{0};
float *tmp_data_ = nullptr;
float *trans_input_ = nullptr;
float *gemm_out_ = nullptr;