forked from mindspore-Ecosystem/mindspore
!24328 [MS][LITE][CPU] winograd optimize
Merge pull request !24328 from liuzhongkai/winograd_op
This commit is contained in: commit 2c738757c3
@@ -35,8 +35,10 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
  int output_tile_count = UP_DIV(output_count, tile_num);
#ifdef ENABLE_AVX
  const int col_tile = C16NUM;
  const int tmp_data_tile = C8NUM;
#else
  const int col_tile = C8NUM;
  const int tmp_data_tile = C4NUM;
#endif
  int oc_tile = UP_DIV(conv_param->output_channel_, col_tile);
  int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
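Note: UP_DIV here is nnacl's ceiling division, so oc_tile counts col_tile-wide output-channel blocks and oc8 counts 8-wide ones. A small standalone illustration follows; the UP_DIV definition is assumed to mirror nnacl/op_base.h, and the channel count is made up for the example.

#include <stdio.h>

/* Assumed to mirror nnacl's ceiling-division macro (nnacl/op_base.h). */
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

int main(void) {
  int output_channel = 37; /* hypothetical channel count */
  int col_tile = 16;       /* C16NUM on the AVX path */
  printf("oc_tile=%d oc8=%d\n", UP_DIV(output_channel, col_tile), UP_DIV(output_channel, 8));
  return 0; /* prints oc_tile=3 oc8=5 */
}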
@@ -51,7 +53,7 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
  float *col_buffer = buffer_list[3];
  int trans_input_offset = tile_num * input_unit_square * in_channel;
  int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
- int tmp_data_offset = input_unit_square * C4NUM;
+ int tmp_data_offset = input_unit_square * tmp_data_tile;
  int col_buffer_offset = tile_num * in_channel;
  // step 1 : filter transform (pre-processed offline)
  // step 2 : input transform (online)
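Note: the two numbered steps follow the usual Winograd convolution factorization. For an input tile d and kernel g,

Y = A^{\top} \left[ (G g G^{\top}) \odot (B^{\top} d B) \right] A

where G g G^{\top} (step 1, the filter transform) is precomputed offline, B^{\top} d B (step 2) is the per-tile input transform done at run time, and the element-wise product \odot is what the per-position GEMM over channels implements.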
@@ -17,6 +17,256 @@
#include "nnacl/fp32/winograd_avx.h"
#include "nnacl/intrinsics/ms_simd_instructions.h"

void InputTransform4x4AvxUnit(const float *src_data, float *dst_data, const int src_step, const int dst_step,
                              const int real_c) {
  if (real_c == C8NUM) {
    MS_FLOAT32X8 src[16];
    MS_FLOAT32X8 t[16];
    MS_FLOAT32X8 m[16];
    LoadAvx16Data;
    for (int l = 0; l < 4; ++l) {
      int offset = l * 4;
      t[l] = MS_SUB256_F32(src[offset], src[2 + offset]);
      t[4 + l] = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
      t[8 + l] = MS_SUB256_F32(src[2 + offset], src[1 + offset]);
      t[12 + l] = MS_SUB256_F32(src[3 + offset], src[1 + offset]);
    }
    for (int l = 0; l < 4; ++l) {
      int offset = l * 4;
      m[l] = MS_SUB256_F32(t[offset], t[2 + offset]);
      m[4 + l] = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
      m[8 + l] = MS_SUB256_F32(t[2 + offset], t[1 + offset]);
      m[12 + l] = MS_SUB256_F32(t[3 + offset], t[1 + offset]);
    }
    for (int i = 0; i < 16; i++) {
      MS_ST256_F32(dst_data + i * dst_step, m[i]);
    }
  } else {
    float src[16];
    float t[16];
    float m[16];
    for (int i = 0; i < real_c; ++i) {
      for (int j = 0; j < 16; ++j) {
        src[j] = src_data[i + j * src_step];
      }
      for (int l = 0; l < 4; ++l) {
        int offset = l * 4;
        t[l] = src[offset] - src[2 + offset];
        t[4 + l] = src[1 + offset] + src[2 + offset];
        t[8 + l] = src[2 + offset] - src[1 + offset];
        t[12 + l] = src[3 + offset] - src[1 + offset];
      }
      for (int l = 0; l < 4; ++l) {
        int offset = l * 4;
        m[l] = t[offset] - t[2 + offset];
        m[4 + l] = t[1 + offset] + t[2 + offset];
        m[8 + l] = t[2 + offset] - t[1 + offset];
        m[12 + l] = t[3 + offset] - t[1 + offset];
      }
      for (int k = 0; k < 16; ++k) {
        dst_data[i + k * dst_step] = m[k];
      }
    }
  }
}
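Note: reading the coefficients off the real_c == C8NUM branch above, the two loops apply the same 4-point transform first to the rows and then to the columns of the tile, i.e. they compute B^{\top} d B with the standard F(2x2, 3x3) input-transform matrix:

B^{\mathsf{T}} =
\begin{pmatrix}
 1 &  0 & -1 & 0 \\
 0 &  1 &  1 & 0 \\
 0 & -1 &  1 & 0 \\
 0 & -1 &  0 & 1
\end{pmatrix}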
void InputTransform6x6AvxUnit(const float *src_data, float *dst_data, const int src_step, const int dst_step,
                              const int real_c) {
  if (real_c == C8NUM) {
    MS_FLOAT32X8 src[36];
    MS_FLOAT32X8 t[36];
    MS_FLOAT32X8 m[36];
    LoadAvx36Data;
    for (int l = 0; l < 6; ++l) {
      int offset = l * 6;
      MS_FLOAT32X8 tmp1 = MS_SUB256_F32(src[3 + offset], src[1 + offset]);
      MS_FLOAT32X8 tmp2 = MS_SUB256_F32(src[4 + offset], src[2 + offset]);
      t[l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[offset], 4), MS_MUL256_N_F32(src[2 + offset], 5)),
                           src[4 + offset]);
      t[6 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_ADD256_F32(src[1 + offset], src[2 + offset]), -4),
                               MS_ADD256_F32(src[3 + offset], src[4 + offset]));
      t[12 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), 4),
                                MS_SUB256_F32(src[4 + offset], src[3 + offset]));
      t[18 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 2), tmp2);
      t[24 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, -2), tmp2);
      t[30 + l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[1 + offset], 4), MS_MUL256_N_F32(src[3 + offset], 5)),
                                src[5 + offset]);
    }
    for (int l = 0; l < 6; ++l) {
      int offset = l * 6;
      MS_FLOAT32X8 tmp1 = MS_SUB256_F32(t[3 + offset], t[1 + offset]);
      MS_FLOAT32X8 tmp2 = MS_SUB256_F32(t[4 + offset], t[2 + offset]);
      m[l] =
        MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[offset], 4), MS_MUL256_N_F32(t[2 + offset], 5)), t[4 + offset]);
      m[6 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_ADD256_F32(t[1 + offset], t[2 + offset]), -4),
                               MS_ADD256_F32(t[3 + offset], t[4 + offset]));
      m[12 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), 4),
                                MS_SUB256_F32(t[4 + offset], t[3 + offset]));
      m[18 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 2), tmp2);
      m[24 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, -2), tmp2);
      m[30 + l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[1 + offset], 4), MS_MUL256_N_F32(t[3 + offset], 5)),
                                t[5 + offset]);
    }
    for (int i = 0; i < 36; i++) {
      MS_ST256_F32(dst_data + i * dst_step, m[i]);
    }
  } else {
    float src[36];
    float t[36];
    float m[36];
    for (int i = 0; i < real_c; ++i) {
      for (int j = 0; j < 36; ++j) {
        src[j] = src_data[i + j * src_step];
      }
      for (int l = 0; l < 6; ++l) {
        int offset = l * 6;
        float tmp1 = src[3 + offset] - src[1 + offset];
        float tmp2 = src[4 + offset] - src[2 + offset];
        t[l] = 4 * src[offset] - 5 * src[2 + offset] + src[4 + offset];
        t[6 + l] = -4 * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]);
        t[12 + l] = 4 * (src[1 + offset] - src[2 + offset]) + (src[4 + offset] - src[3 + offset]);
        t[18 + l] = 2 * tmp1 + tmp2;
        t[24 + l] = -2 * tmp1 + tmp2;
        t[30 + l] = 4 * src[1 + offset] - 5 * src[3 + offset] + src[5 + offset];
      }
      for (int l = 0; l < 6; ++l) {
        int offset = l * 6;
        float tmp1 = t[3 + offset] - t[1 + offset];
        float tmp2 = t[4 + offset] - t[2 + offset];
        m[l] = 4 * t[offset] - 5 * t[2 + offset] + t[4 + offset];
        m[6 + l] = -4 * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]);
        m[12 + l] = 4 * (t[1 + offset] - t[2 + offset]) + (t[4 + offset] - t[3 + offset]);
        m[18 + l] = 2 * tmp1 + tmp2;
        m[24 + l] = -2 * tmp1 + tmp2;
        m[30 + l] = 4 * t[1 + offset] - 5 * t[3 + offset] + t[5 + offset];
      }
      for (int k = 0; k < 36; ++k) {
        dst_data[i + k * dst_step] = m[k];
      }
    }
  }
}
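Note: likewise, the 4 / -5 / +-2 / +-1 coefficients in InputTransform6x6AvxUnit correspond to the standard F(4x4, 3x3) input-transform matrix, applied row-wise and then column-wise:

B^{\mathsf{T}} =
\begin{pmatrix}
 4 &  0 & -5 &  0 & 1 & 0 \\
 0 & -4 & -4 &  1 & 1 & 0 \\
 0 &  4 & -4 & -1 & 1 & 0 \\
 0 & -2 & -1 &  2 & 1 & 0 \\
 0 &  2 & -1 & -2 & 1 & 0 \\
 0 &  4 &  0 & -5 & 0 & 1
\end{pmatrix}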
void InputTransform8x8AvxUnit_block8(const float *src_data, float *dst_data, const int src_step, const int dst_step) {
  MS_FLOAT32X8 src[64];
  MS_FLOAT32X8 t[64];
  MS_FLOAT32X8 m[64];
  LoadAvx64Data;
  for (int l = 0; l < 8; ++l) {
    int offset = l * 8;
    t[l] = MS_SUB256_F32(
      MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[offset], 0.5625), MS_MUL256_N_F32(src[2 + offset], 3.0625)),
                    MS_MUL256_N_F32(src[4 + offset], 3.5)),
      src[6 + offset]);
    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 1.125), MS_MUL256_N_F32(src[5 + offset], 0.5));
    MS_FLOAT32X8 tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 2.25), MS_MUL256_N_F32(src[4 + offset], 3.25));
    t[8 + l] =
      MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 1.625)), src[6 + offset]);
    t[16 + l] =
      MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 1.625)), src[6 + offset]);
    tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 0.5625), src[5 + offset]);
    tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 0.5625), MS_MUL256_N_F32(src[4 + offset], 2.5));
    t[24 + l] =
      MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 2.5)), src[6 + offset]);
    t[32 + l] =
      MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 2.5)), src[6 + offset]);
    tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 0.375), MS_MUL256_N_F32(src[5 + offset], 1.5));
    tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 0.25), MS_MUL256_N_F32(src[4 + offset], 1.25));
    t[40 + l] =
      MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 1.875)), src[6 + offset]);
    t[48 + l] =
      MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 1.875)), src[6 + offset]);
    t[56 + l] = MS_ADD256_F32(
      MS_SUB256_F32(MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], -0.5625), MS_MUL256_N_F32(src[3 + offset], 3.0625)),
                    MS_MUL256_N_F32(src[5 + offset], 3.5)),
      src[7 + offset]);
  }
  for (int l = 0; l < 8; ++l) {
    int offset = l * 8;
    m[l] = MS_SUB256_F32(
      MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[offset], 0.5625), MS_MUL256_N_F32(t[2 + offset], 3.0625)),
                    MS_MUL256_N_F32(t[4 + offset], 3.5)),
      t[6 + offset]);
    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 1.125), MS_MUL256_N_F32(t[5 + offset], 0.5));
    MS_FLOAT32X8 tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 2.25), MS_MUL256_N_F32(t[4 + offset], 3.25));
    m[8 + l] =
      MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 1.625)), t[6 + offset]);
    m[16 + l] =
      MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 1.625)), t[6 + offset]);
    tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 0.5625), t[5 + offset]);
    tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 0.5625), MS_MUL256_N_F32(t[4 + offset], 2.5));
    m[24 + l] =
      MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 2.5)), t[6 + offset]);
    m[32 + l] =
      MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 2.5)), t[6 + offset]);
    tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 0.375), MS_MUL256_N_F32(t[5 + offset], 1.5));
    tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 0.25), MS_MUL256_N_F32(t[4 + offset], 1.25));
    m[40 + l] =
      MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 1.875)), t[6 + offset]);
    m[48 + l] =
      MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 1.875)), t[6 + offset]);
    m[56 + l] = MS_ADD256_F32(
      MS_SUB256_F32(MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], -0.5625), MS_MUL256_N_F32(t[3 + offset], 3.0625)),
                    MS_MUL256_N_F32(t[5 + offset], 3.5)),
      t[7 + offset]);
  }
  for (int i = 0; i < 64; i++) {
    MS_ST256_F32(dst_data + i * dst_step, m[i]);
  }
}

void InputTransform8x8AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
  if (real_c == C8NUM) {
    InputTransform8x8AvxUnit_block8(src_data, dst_data, src_step, dst_step);
  } else {
    float src[64];
    float t[64];
    float m[64];
    for (int i = 0; i < real_c; ++i) {
      for (int j = 0; j < 64; ++j) {
        src[j] = src_data[i + j * src_step];
      }
      for (int l = 0; l < 8; ++l) {
        int offset = l * 8;
        t[l] = 0.5625f * src[offset] - 3.0625f * src[2 + offset] + 3.5f * src[4 + offset] - src[6 + offset];
        float tmp1 = 1.125f * src[1 + offset] + 0.5f * src[5 + offset];
        float tmp2 = 2.25f * src[2 + offset] - 3.25f * src[4 + offset];
        t[8 + l] = tmp1 + tmp2 - 1.625f * src[3 + offset] + src[6 + offset];
        t[16 + l] = tmp2 - tmp1 + 1.625f * src[3 + offset] + src[6 + offset];
        tmp1 = 0.5625f * src[1 + offset] + src[5 + offset];
        tmp2 = 0.5625f * src[2 + offset] - 2.5f * src[4 + offset];
        t[24 + l] = tmp1 + tmp2 - 2.5f * src[3 + offset] + src[6 + offset];
        t[32 + l] = tmp2 - tmp1 + 2.5f * src[3 + offset] + src[6 + offset];
        tmp1 = 0.375f * src[1 + offset] + 1.5f * src[5 + offset];
        tmp2 = 0.25f * src[2 + offset] - 1.25f * src[4 + offset];
        t[40 + l] = tmp1 + tmp2 - 1.875f * src[3 + offset] + src[6 + offset];
        t[48 + l] = tmp2 - tmp1 + 1.875f * src[3 + offset] + src[6 + offset];
        t[56 + l] = -0.5625f * src[1 + offset] + 3.0625f * src[3 + offset] - 3.5f * src[5 + offset] + src[7 + offset];
      }
      for (int l = 0; l < 8; ++l) {
        int offset = l * 8;
        m[l] = 0.5625f * t[offset] - 3.0625f * t[2 + offset] + 3.5f * t[4 + offset] - t[6 + offset];
        float tmp1 = 1.125f * t[1 + offset] + 0.5f * t[5 + offset];
        float tmp2 = 2.25f * t[2 + offset] - 3.25f * t[4 + offset];
        m[8 + l] = tmp1 + tmp2 - 1.625f * t[3 + offset] + t[6 + offset];
        m[16 + l] = tmp2 - tmp1 + 1.625f * t[3 + offset] + t[6 + offset];
        tmp1 = 0.5625f * t[1 + offset] + t[5 + offset];
        tmp2 = 0.5625f * t[2 + offset] - 2.5f * t[4 + offset];
        m[24 + l] = tmp1 + tmp2 - 2.5f * t[3 + offset] + t[6 + offset];
        m[32 + l] = tmp2 - tmp1 + 2.5f * t[3 + offset] + t[6 + offset];
        tmp1 = 0.375f * t[1 + offset] + 1.5f * t[5 + offset];
        tmp2 = 0.25f * t[2 + offset] - 1.25f * t[4 + offset];
        m[40 + l] = tmp1 + tmp2 - 1.875f * t[3 + offset] + t[6 + offset];
        m[48 + l] = tmp2 - tmp1 + 1.875f * t[3 + offset] + t[6 + offset];
        m[56 + l] = -0.5625f * t[1 + offset] + 3.0625f * t[3 + offset] - 3.5f * t[5 + offset] + t[7 + offset];
      }
      for (int k = 0; k < 64; ++k) {
        dst_data[i + k * dst_step] = m[k];
      }
    }
  }
}
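Note: each AVX unit keeps a scalar per-channel fallback for partial blocks (real_c < C8NUM), so the vector path can be cross-checked against it. The standalone sketch below is not part of the patch; it only assumes the InputTransform4x4AvxUnit signature declared in nnacl/fp32/winograd_avx.h, a channel-contiguous layout with 8 floats per tile point, and that the AVX path tolerates unaligned loads.

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "nnacl/fp32/winograd_avx.h"

int main(void) {
  enum { kPoints = 16, kChannels = 8 }; /* 4x4 tile, one 8-channel block */
  float src[kPoints * kChannels];
  float dst_avx[kPoints * kChannels];
  float dst_ref[kPoints * kChannels];
  for (int i = 0; i < kPoints * kChannels; ++i) {
    src[i] = (float)rand() / (float)RAND_MAX - 0.5f;
  }
  /* Vector path: one call handles all 8 channels (real_c == C8NUM). */
  InputTransform4x4AvxUnit(src, dst_avx, kChannels, kChannels, kChannels);
  /* Scalar path: force the per-channel fallback by passing real_c = 1. */
  for (int c = 0; c < kChannels; ++c) {
    InputTransform4x4AvxUnit(src + c, dst_ref + c, kChannels, kChannels, 1);
  }
  float max_diff = 0.0f;
  for (int i = 0; i < kPoints * kChannels; ++i) {
    float diff = fabsf(dst_avx[i] - dst_ref[i]);
    if (diff > max_diff) max_diff = diff;
  }
  printf("max |avx - scalar| = %g\n", max_diff);
  return max_diff < 1e-5f ? 0 : 1;
}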
void OutputTransform4x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
                               int dst_step, int out_c, int r_w, int r_h, int r_c) {
  MS_FLOAT32X8 src[16];
@@ -210,6 +210,12 @@ typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const fl
  MS_ST256_F32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \
  MS_ST256_F32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]);

void InputTransform4x4AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);

void InputTransform6x6AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);

void InputTransform8x8AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);

void OutputTransform4x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
                               int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform4x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
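Note: the LoadAvx16Data / LoadAvx36Data / LoadAvx64Data statements used in winograd_avx.c refer to load macros defined in this header next to the store macro above. Their bodies are not shown in this hunk, but they presumably unroll a strided block load, roughly equivalent to the following fragment (assumed shape, 16-point case):

/* Assumed behaviour of LoadAvx16Data (the real macro is fully unrolled):
 * for each of the 16 tile points, load one 8-float channel block from src_data. */
for (int i = 0; i < 16; ++i) {
  src[i] = MS_LD256_F32(src_data + i * src_step);
}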
@@ -24,7 +24,12 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
  int input_unit = conv_param->input_unit_;
  int output_unit = conv_param->output_unit_;
  int in_channel = conv_param->input_channel_;
- int ic4 = UP_DIV(in_channel, C4NUM);
+#ifdef ENABLE_AVX
+ int tile = C8NUM;
+#else
+ int tile = C4NUM;
+#endif
+ int ic4 = UP_DIV(in_channel, tile);
  int pad_h = conv_param->pad_u_;
  int pad_w = conv_param->pad_l_;
  int input_h = conv_param->input_h_;
@@ -45,25 +50,27 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
    int dst_plane_offset = c * in_channel;
    for (int ic = 0; ic < ic4; ic++) {
      // clear tmp buffer
-     memset(tmp_data, 0, input_unit * input_unit * C4NUM * (int)(sizeof(float)));
+     memset(tmp_data, 0, input_unit * input_unit * tile * (int)(sizeof(float)));

-     int real_c = in_channel - ic * C4NUM;
-     real_c = real_c > C4NUM ? C4NUM : real_c;
-     int src_ic4_offset = src_plane_offset + ic * C4NUM;
+     int real_c = in_channel - ic * tile;
+     real_c = real_c > tile ? tile : real_c;
+     int src_ic4_offset = src_plane_offset + ic * tile;
      // get real input block with padding
-     if (real_c == C4NUM) {
+     if (real_c == tile) {
        for (int interval = interval_y_s; interval < interval_y_e; interval++) {
          int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * in_channel;
-         int dst_y_offset = interval * input_unit * C4NUM + interval_x_s * C4NUM;
+         int dst_y_offset = interval * input_unit * tile + interval_x_s * tile;
          for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
            int src_x_offset = src_y_offset + j * in_channel;
-           int dst_x_offset = dst_y_offset + j * C4NUM;
+           int dst_x_offset = dst_y_offset + j * tile;
            float *src_addr = (float *)(input_data) + src_x_offset;
            float *dst_addr = tmp_data + dst_x_offset;
-#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
+#ifdef ENABLE_AVX
+           MS_ST256_F32(dst_addr, MS_LD256_F32(src_addr));
+#elif defined(ENABLE_ARM) || defined(ENABLE_SSE)
            MS_STQ_F32(dst_addr, MS_LDQ_F32(src_addr));
#else
-           for (int k = 0; k < C4NUM; k++) {
+           for (int k = 0; k < tile; k++) {
              dst_addr[k] = src_addr[k];
            }
#endif
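Note: MS_LD256_F32 / MS_ST256_F32 and MS_LDQ_F32 / MS_STQ_F32 come from nnacl/intrinsics/ms_simd_instructions.h. A simplified, assumed mapping onto the platform intrinsics is sketched below; the real header also covers the other platform variants.

/* Assumed, simplified definitions of the wrappers used above (illustration only). */
#ifdef ENABLE_AVX
#include <immintrin.h>
#define MS_FLOAT32X8 __m256
#define MS_LD256_F32(addr) _mm256_loadu_ps(addr)      /* unaligned 8 x float load */
#define MS_ST256_F32(addr, v) _mm256_storeu_ps(addr, v)
#endif
#ifdef ENABLE_SSE
#include <xmmintrin.h>
#define MS_LDQ_F32(addr) _mm_loadu_ps(addr)           /* unaligned 4 x float load */
#define MS_STQ_F32(addr, v) _mm_storeu_ps(addr, v)
#elif defined(ENABLE_ARM)
#include <arm_neon.h>
#define MS_LDQ_F32(addr) vld1q_f32(addr)              /* NEON 4 x float load */
#define MS_STQ_F32(addr, v) vst1q_f32(addr, v)
#endif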
@@ -72,10 +79,10 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
      } else {
        for (int interval = interval_y_s; interval < interval_y_e; interval++) {
          int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * in_channel;
-         int dst_y_offset = interval * input_unit * C4NUM + interval_x_s * C4NUM;
+         int dst_y_offset = interval * input_unit * tile + interval_x_s * tile;
          for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
            int src_x_offset = src_y_offset + j * in_channel;
-           int dst_x_offset = dst_y_offset + j * C4NUM;
+           int dst_x_offset = dst_y_offset + j * tile;
            float *src_addr = (float *)(input_data) + src_x_offset;
            float *dst_addr = tmp_data + dst_x_offset;
            for (int k = 0; k < real_c; k++) {
@@ -86,10 +93,10 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
      }
      // input transform
      const int tile_num = C12NUM;
-     int dst_ic4_offset = dst_plane_offset + ic * C4NUM;
+     int dst_ic4_offset = dst_plane_offset + ic * tile;
      int dst_step = tile_num * in_channel;
      float *trans_input_ptr = trans_input + dst_ic4_offset;
-     func(tmp_data, trans_input_ptr, C4NUM, dst_step, real_c);
+     func(tmp_data, trans_input_ptr, tile, dst_step, real_c);
    }
    out_tile_index++;
  }  // cal_tile_num loop
@@ -20,10 +20,10 @@
#include "nnacl/base/conv_common_base.h"
#include "nnacl/errorcode.h"

-static InputTransFunc InputTransFuncList[] = {
-  NULL, NULL, NULL, NULL, InputTransform4x4Unit, NULL, InputTransform6x6Unit, NULL, InputTransform8x8Unit};

+#ifdef ENABLE_AVX
+static InputTransFunc InputTransFuncList[] = {
+  NULL, NULL, NULL, NULL, InputTransform4x4AvxUnit, NULL, InputTransform6x6AvxUnit, NULL, InputTransform8x8AvxUnit};

+static OutputTransFunc OutputTransFuncList[] = {
+  OutputTransform4x2AvxUnit, OutputTransform4x3AvxUnit, OutputTransform4x2ReluAvxUnit,
+  OutputTransform4x3ReluAvxUnit, OutputTransform4x2Relu6AvxUnit, OutputTransform4x3Relu6AvxUnit,

@@ -38,6 +38,9 @@ static OutputTransFunc OutputTransFuncList[] = {
  OutputTransform8x2Relu6AvxUnit, OutputTransform8x3Relu6AvxUnit, OutputTransform8x4Relu6AvxUnit,
  OutputTransform8x5Relu6AvxUnit, OutputTransform8x6Relu6AvxUnit, OutputTransform8x7Relu6AvxUnit};
#else
static InputTransFunc InputTransFuncList[] = {
  NULL, NULL, NULL, NULL, InputTransform4x4Unit, NULL, InputTransform6x6Unit, NULL, InputTransform8x8Unit};

static OutputTransFunc OutputTransFuncList[] = {
  OutputTransform4x2Unit, OutputTransform4x3Unit, OutputTransform4x2ReluUnit, OutputTransform4x3ReluUnit,
  OutputTransform4x2Relu6Unit, OutputTransform4x3Relu6Unit, OutputTransform6x2Unit, OutputTransform6x3Unit,
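Note: these tables are indexed by transform size, so the AVX build picks up the new units without touching the callers. The lookup helper in winograd_utils.c is not shown in this hunk; a rough sketch of the assumed shape:

/* Sketch of the assumed lookup: input_unit selects the 4x4 / 6x6 / 8x8 entry directly. */
InputTransFunc GetInputTransFunc(int input_unit) {
  if (input_unit < 0 || input_unit >= (int)(sizeof(InputTransFuncList) / sizeof(InputTransFuncList[0]))) {
    return NULL;
  }
  return InputTransFuncList[input_unit];  /* NULL entries mean "unsupported size" */
}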
@@ -56,7 +56,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
  }

  tmp_data_ = reinterpret_cast<float *>(
-   ctx_->allocator->Malloc(thread_count_ * C4NUM * input_unit_ * input_unit_ * sizeof(float)));
+   ctx_->allocator->Malloc(thread_count_ * tmp_data_tile_ * input_unit_ * input_unit_ * sizeof(float)));
  if (tmp_data_ == nullptr) {
    MS_LOG(ERROR) << "malloc tmp_data_ failed.";
    return RET_MEMORY_FAILED;
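A worked example with illustrative numbers (not from the patch): for input_unit_ = 8 and thread_count_ = 4, the AVX build now reserves 4 * 8 * 8 * 8 * sizeof(float) = 8 KiB for tmp_data_, where the previous C4NUM sizing gave 4 * 4 * 8 * 8 * sizeof(float) = 4 KiB; each per-thread scratch block now holds C8NUM channels to match the 256-bit loads and stores.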
@@ -96,8 +96,10 @@ int ConvolutionWinogradCPUKernel::Prepare() {
  tile_num_ = C12NUM;
#ifdef ENABLE_AVX
  oc_block_ = C16NUM;
  tmp_data_tile_ = C8NUM;
#else
  oc_block_ = C8NUM;
  tmp_data_tile_ = C4NUM;
#endif
  kernel_unit_ = conv_param_->kernel_h_;
  input_unit_ = output_unit_ + kernel_unit_ - 1;
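For example, with a 3x3 kernel (kernel_unit_ = 3) and output_unit_ = 6, input_unit_ becomes 6 + 3 - 1 = 8, which is the case served by InputTransform8x8AvxUnit and the OutputTransform8xN AVX units registered above.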
@@ -68,6 +68,7 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
  int output_unit_{0};
  int oc_block_{0};
  int tile_num_{0};
  int tmp_data_tile_{0};
  float *tmp_data_ = nullptr;
  float *trans_input_ = nullptr;
  float *gemm_out_ = nullptr;