optimize winograd input transform func

This commit is contained in:
fuzhiye 2020-09-28 14:50:24 +08:00
parent 8b4cdc1523
commit 2d00b74de2
7 changed files with 239 additions and 470 deletions

View File

@ -77,7 +77,6 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
int input_unit = conv_param->input_unit_;
int in_batch = conv_param->input_batch_;
int in_channel = conv_param->input_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int out_unit = conv_param->output_unit_;
int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
@ -96,10 +95,10 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
float *gemm_out = buffer_list[1];
float *tmp_data = buffer_list[2];
float *col_buffer = buffer_list[3];
int trans_input_offset = tile_num * input_unit_square * ic4 * C4NUM;
int trans_input_offset = tile_num * input_unit_square * in_channel;
int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
int tmp_data_offset = input_unit_square * C4NUM;
int col_buffer_offset = tile_num * ic4 * C4NUM;
int col_buffer_offset = tile_num * in_channel;
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)
for (int b = 0; b < in_batch; b++) {
@ -107,7 +106,7 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
int out_batch_offset = b * out_channel * conv_param->output_w_ * conv_param->output_h_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
int out_tile_index = thread_id * tile_num;
int cal_num = output_count - thread_id * tile_num;
int cal_num = output_count - out_tile_index;
cal_num = cal_num > tile_num ? tile_num : cal_num;
WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
@ -118,11 +117,11 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
#ifdef ENABLE_ARM32
RowMajor2Col4Major(src_ptr + i * C4NUM * ic4 * C4NUM, tmp_col_ptr, C4NUM, ic4 * C4NUM);
RowMajor2Col4Major(src_ptr + i * C4NUM * in_channel, tmp_col_ptr, C4NUM, in_channel);
#else
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
#endif
MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM,
MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
cal_num, oc8 * C8NUM, input_unit_square, 2);
}

View File

@ -630,26 +630,6 @@ void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int c
}
}
void PackNCHWToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
int nhwc4_batch_offset = 0;
int c4 = UP_DIV(channel, C4NUM);
int nhwc4_batch_unit_offset = c4 * C4NUM * plane;
for (int b = 0; b < batch; b++) {
int batch_offset = b * channel * plane;
for (int c = 0; c < channel; c++) {
int src_c_offset = batch_offset + c * plane;
int dst_c_offset = nhwc4_batch_offset + c;
for (int i = 0; i < plane; i++) {
int src_plane_offset = src_c_offset + i;
int dst_plane_offset = dst_c_offset + i * c4 * C4NUM;
((float *)dst)[dst_plane_offset] = ((float *)src)[src_plane_offset];
}
}
nhwc4_batch_offset += nhwc4_batch_unit_offset;
}
}
void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
@ -700,105 +680,6 @@ void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int
}
}
void PackNC4HW4ToNHWCReluFp32(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * c4 * C4NUM;
int dst_offset = b * plane * channel;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_offset + k * C4NUM;
int dst_kernel_offset = dst_offset + k * channel;
for (int c = 0; c < c4 - 1; c++) {
int src_c_offset = src_kernel_offset + c * plane * C4NUM;
int dst_c_offset = dst_kernel_offset + c * C4NUM;
#ifdef ENABLE_NEON
float32x4_t input_ptr = vld1q_f32((float *)src + src_c_offset);
float32x4_t zero = vdupq_n_f32(0);
input_ptr = vmaxq_f32(zero, input_ptr);
vst1q_f32((float *)dst + dst_c_offset, input_ptr);
#else
for (int i = 0; i < C4NUM; ++i) {
float input_data = ((float *)src + src_c_offset)[i];
input_data = input_data < 0 ? 0 : input_data;
((float *)dst + dst_c_offset)[i] = input_data;
}
#endif
}
// res part
int res_c = channel - (c4 - 1) * C4NUM;
for (int i = 0; i < res_c; i++) {
int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
float input_data = ((float *)src + src_res_c_offset)[0];
input_data = input_data < 0 ? 0 : input_data;
((float *)dst + dst_res_c_offset)[0] = input_data;
}
}
}
}
void PackNC4HW4ToNHWCRelu6Fp32(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * c4 * C4NUM;
int dst_offset = b * plane * channel;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_offset + k * C4NUM;
int dst_kernel_offset = dst_offset + k * channel;
for (int c = 0; c < c4 - 1; c++) {
int src_c_offset = src_kernel_offset + c * plane * C4NUM;
int dst_c_offset = dst_kernel_offset + c * C4NUM;
#ifdef ENABLE_NEON
float32x4_t input_ptr = vld1q_f32((float *)src + src_c_offset);
float32x4_t zero = vdupq_n_f32(0);
float32x4_t six = vdupq_n_f32(6);
input_ptr = vmaxq_f32(zero, input_ptr);
input_ptr = vminq_f32(six, input_ptr);
vst1q_f32((float *)dst + dst_c_offset, input_ptr);
#else
for (int i = 0; i < C4NUM; ++i) {
float input_data = ((float *)src + src_c_offset)[i];
input_data = input_data < 0 ? 0 : input_data;
input_data = input_data > 6 ? 6 : input_data;
((float *)dst + dst_c_offset)[i] = input_data;
}
#endif
}
// res part
int res_c = channel - (c4 - 1) * C4NUM;
for (int i = 0; i < res_c; i++) {
int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
float input_data = ((float *)src + src_res_c_offset)[0];
input_data = input_data < 0 ? 0 : input_data;
input_data = input_data > 6 ? 6 : input_data;
((float *)dst + dst_res_c_offset)[0] = input_data;
}
}
}
}
void PackNC4HW4ToNHWCPreluFp32(const void *src, void *dst, const void *slope, int batch, int plane, int channel) {}
void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * c4 * C4NUM;
int dst_offset = b * plane * channel;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_res = c % C4NUM;
int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
int dst_c_offset = dst_offset + c * plane;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k * C4NUM;
int dst_kernel_offset = dst_c_offset + k;
((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0];
}
}
}
}
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
for (int n = 0; n < batch; n++) {
for (int hw = 0; hw < plane; hw++) {
@ -896,45 +777,6 @@ void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int c
}
}
void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) {
int nhwc4_batch_offset = 0;
int c4 = UP_DIV(channel, C4NUM);
int nhwc4_batch_unit_offset = c4 * C4NUM * plane;
for (int b = 0; b < batch; b++) {
int batch_offset = b * channel * plane;
for (int c = 0; c < channel; c++) {
int src_c_offset = batch_offset + c * plane;
int dst_c_offset = nhwc4_batch_offset + c;
for (int i = 0; i < plane; i++) {
int src_plane_offset = src_c_offset + i;
int dst_plane_offset = dst_c_offset + i * c4 * C4NUM;
((uint8_t *)dst)[dst_plane_offset] = ((uint8_t *)src)[src_plane_offset];
}
}
nhwc4_batch_offset += nhwc4_batch_unit_offset;
}
}
void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * c4 * C4NUM;
int dst_offset = b * plane * channel;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_res = c % C4NUM;
int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k * C4NUM;
int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM;
((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0];
}
}
}
}
void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
@ -962,25 +804,6 @@ void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int
}
}
void PackNC4HW4ToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * c4 * C4NUM;
int dst_offset = b * plane * channel;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_res = c % C4NUM;
int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
int dst_c_offset = dst_offset + c * plane;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k * C4NUM;
int dst_kernel_offset = dst_c_offset + k;
((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0];
}
}
}
}
void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel) {
for (int n = 0; n < batch; n++) {
for (int hw = 0; hw < plane; hw++) {
@ -996,25 +819,6 @@ void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int
return;
}
void PackNHWCToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel) {
int c8 = UP_DIV(channel, C8NUM);
for (int b = 0; b < batch; b++) {
int src_oc_offset = b * plane * channel;
int dst_oc_offset = b * plane * c8 * C8NUM;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_oc_offset + k * channel;
int dst_kernel_offset = dst_oc_offset + k * C8NUM;
for (int i = 0; i < channel; i++) {
int c8_block_num = i / C8NUM;
int c8_block_rem = i % C8NUM;
int src_ic_offset = src_kernel_offset + i;
int dst_ic_offset = dst_kernel_offset + c8_block_num * plane * C8NUM + c8_block_rem;
((int8_t *)dst + dst_ic_offset)[0] = ((int8_t *)src + src_ic_offset)[0];
}
}
}
}
void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
for (int n = 0; n < batch; n++) {
for (int c = 0; c < channel; c++) {
@ -1231,27 +1035,6 @@ void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int ch
return PackNHWCToNCHWFp32(src, dst, batch, channel, plane);
}
void MatrixPackUnit(const float *src, float *dst, size_t row, size_t col, size_t src_stride, size_t dst_stride) {
size_t copy_size = row * C4NUM * sizeof(float);
for (int c = 0; c < col; c++) {
memcpy(dst + c * dst_stride, src + c * src_stride, copy_size);
}
}
void MatrixPack(const float *src, float *dst, int row, int ic4, int stride) {
int row4mod = row % 4;
int row4div = row / 4;
for (int i = 0; i < row4div; i++) {
MatrixPackUnit(src + i * 4 * 4, dst + i * 4 * ic4 * 4, 4, ic4, stride, 16);
}
if (row4mod > 0) {
MatrixPackUnit(src + row4div * 4 * 4, dst + row4div * 4 * ic4 * 4, row4mod, ic4, stride, row4mod * 4);
}
return;
}
void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param) {
int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);

View File

@ -46,11 +46,6 @@ void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParam
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param);
void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
size_t plane_size, ConvParameter *conv_param);
void MatrixPack(const float *src, float *dst, int row, int ic4, int stride);
void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param);
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
@ -75,20 +70,10 @@ void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int ch
void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCReluFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCRelu6Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCPreluFp32(const void *src, void *dst, const void *slope, int batch, int plane, int channel);
void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
@ -99,18 +84,10 @@ void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int c
void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param);

View File

@ -42,7 +42,7 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s);
int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s);
int dst_plane_offset = c * C4NUM * ic4;
int dst_plane_offset = c * in_channel;
for (int ic = 0; ic < ic4; ic++) {
// clear tmp buffer
memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float));
@ -91,9 +91,9 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
const int tile_num = 12;
#endif
int dst_ic4_offset = dst_plane_offset + ic * C4NUM;
size_t dst_step = tile_num * ic4 * C4NUM;
size_t dst_step = tile_num * in_channel;
float *trans_input_ptr = trans_input + dst_ic4_offset;
func(tmp_data, trans_input_ptr, C4NUM, dst_step);
func(tmp_data, trans_input_ptr, C4NUM, dst_step, real_c);
}
out_tile_index++;
} // cal_tile_num loop

View File

@ -171,227 +171,241 @@ void GeneralOutputTransformUnit(const float *src_data, float *dst_data, const fl
InputTransFunc GetInputTransFunc(int input_unit) { return InputTransFuncList[input_unit]; }
void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step) {
void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
#ifdef ENABLE_ARM
float32x4_t src[16];
float32x4_t t[16];
float32x4_t m[16];
Load16Data;
for (int l = 0; l < 4; ++l) {
int offset = l * 4;
t[l] = vsubq_f32(src[offset], src[2 + offset]);
t[4 + l] = vaddq_f32(src[1 + offset], src[2 + offset]);
t[8 + l] = vsubq_f32(src[2 + offset], src[1 + offset]);
t[12 + l] = vsubq_f32(src[3 + offset], src[1 + offset]);
}
for (int l = 0; l < 4; ++l) {
int offset = l * 4;
m[l] = vsubq_f32(t[offset], t[2 + offset]);
m[4 + l] = vaddq_f32(t[1 + offset], t[2 + offset]);
m[8 + l] = vsubq_f32(t[2 + offset], t[1 + offset]);
m[12 + l] = vsubq_f32(t[3 + offset], t[1 + offset]);
}
for (int i = 0; i < 16; i++) {
vst1q_f32(dst_data + i * dst_step, m[i]);
}
#else
float src[16];
float t[16];
float m[16];
for (int i = 0; i < C4NUM; ++i) {
for (int j = 0; j < 16; ++j) {
src[j] = src_data[i + j * src_step];
if (real_c == 4) {
float32x4_t src[16];
float32x4_t t[16];
float32x4_t m[16];
Load16Data;
for (int l = 0; l < 4; ++l) {
int offset = l * 4;
t[l] = vsubq_f32(src[offset], src[2 + offset]);
t[4 + l] = vaddq_f32(src[1 + offset], src[2 + offset]);
t[8 + l] = vsubq_f32(src[2 + offset], src[1 + offset]);
t[12 + l] = vsubq_f32(src[3 + offset], src[1 + offset]);
}
for (int l = 0; l < 4; ++l) {
int offset = l * 4;
t[l] = src[offset] - src[2 + offset];
t[4 + l] = src[1 + offset] + src[2 + offset];
t[8 + l] = src[2 + offset] - src[1 + offset];
t[12 + l] = src[3 + offset] - src[1 + offset];
m[l] = vsubq_f32(t[offset], t[2 + offset]);
m[4 + l] = vaddq_f32(t[1 + offset], t[2 + offset]);
m[8 + l] = vsubq_f32(t[2 + offset], t[1 + offset]);
m[12 + l] = vsubq_f32(t[3 + offset], t[1 + offset]);
}
for (int l = 0; l < 4; ++l) {
int offset = l * 4;
m[l] = t[offset] - t[2 + offset];
m[4 + l] = t[1 + offset] + t[2 + offset];
m[8 + l] = t[2 + offset] - t[1 + offset];
m[12 + l] = t[3 + offset] - t[1 + offset];
for (int i = 0; i < 16; i++) {
vst1q_f32(dst_data + i * dst_step, m[i]);
}
for (int k = 0; k < 16; ++k) {
dst_data[i + k * dst_step] = m[k];
} else {
#endif
float src[16];
float t[16];
float m[16];
for (int i = 0; i < real_c; ++i) {
for (int j = 0; j < 16; ++j) {
src[j] = src_data[i + j * src_step];
}
for (int l = 0; l < 4; ++l) {
int offset = l * 4;
t[l] = src[offset] - src[2 + offset];
t[4 + l] = src[1 + offset] + src[2 + offset];
t[8 + l] = src[2 + offset] - src[1 + offset];
t[12 + l] = src[3 + offset] - src[1 + offset];
}
for (int l = 0; l < 4; ++l) {
int offset = l * 4;
m[l] = t[offset] - t[2 + offset];
m[4 + l] = t[1 + offset] + t[2 + offset];
m[8 + l] = t[2 + offset] - t[1 + offset];
m[12 + l] = t[3 + offset] - t[1 + offset];
}
for (int k = 0; k < 16; ++k) {
dst_data[i + k * dst_step] = m[k];
}
}
#ifdef ENABLE_ARM
}
#endif
}
void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step) {
void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
#ifdef ENABLE_ARM
float32x4_t src[36];
float32x4_t t[36];
float32x4_t m[36];
Load36Data;
for (int l = 0; l < 6; ++l) {
int offset = l * 6;
float32x4_t tmp1 = vsubq_f32(src[3 + offset], src[1 + offset]);
float32x4_t tmp2 = vsubq_f32(src[4 + offset], src[2 + offset]);
t[l] = vaddq_f32(vsubq_f32(vmulq_n_f32(src[offset], 4), vmulq_n_f32(src[2 + offset], 5)), src[4 + offset]);
t[6 + l] = vaddq_f32(vmulq_n_f32(vaddq_f32(src[1 + offset], src[2 + offset]), -4),
vaddq_f32(src[3 + offset], src[4 + offset]));
t[12 + l] = vaddq_f32(vmulq_n_f32(vsubq_f32(src[1 + offset], src[2 + offset]), 4),
vsubq_f32(src[4 + offset], src[3 + offset]));
t[18 + l] = vaddq_f32(vmulq_n_f32(tmp1, 2), tmp2);
t[24 + l] = vaddq_f32(vmulq_n_f32(tmp1, -2), tmp2);
t[30 + l] = vaddq_f32(vsubq_f32(vmulq_n_f32(src[1 + offset], 4), vmulq_n_f32(src[3 + offset], 5)), src[5 + offset]);
}
for (int l = 0; l < 6; ++l) {
int offset = l * 6;
float32x4_t tmp1 = vsubq_f32(t[3 + offset], t[1 + offset]);
float32x4_t tmp2 = vsubq_f32(t[4 + offset], t[2 + offset]);
m[l] = vaddq_f32(vsubq_f32(vmulq_n_f32(t[offset], 4), vmulq_n_f32(t[2 + offset], 5)), t[4 + offset]);
m[6 + l] =
vaddq_f32(vmulq_n_f32(vaddq_f32(t[1 + offset], t[2 + offset]), -4), vaddq_f32(t[3 + offset], t[4 + offset]));
m[12 + l] =
vaddq_f32(vmulq_n_f32(vsubq_f32(t[1 + offset], t[2 + offset]), 4), vsubq_f32(t[4 + offset], t[3 + offset]));
m[18 + l] = vaddq_f32(vmulq_n_f32(tmp1, 2), tmp2);
m[24 + l] = vaddq_f32(vmulq_n_f32(tmp1, -2), tmp2);
m[30 + l] = vaddq_f32(vsubq_f32(vmulq_n_f32(t[1 + offset], 4), vmulq_n_f32(t[3 + offset], 5)), t[5 + offset]);
}
for (int i = 0; i < 36; i++) {
vst1q_f32(dst_data + i * dst_step, m[i]);
}
#else
float src[36];
float t[36];
float m[36];
for (int i = 0; i < C4NUM; ++i) {
for (int j = 0; j < 36; ++j) {
src[j] = src_data[i + j * src_step];
if (real_c == 4) {
float32x4_t src[36];
float32x4_t t[36];
float32x4_t m[36];
Load36Data;
for (int l = 0; l < 6; ++l) {
int offset = l * 6;
float32x4_t tmp1 = vsubq_f32(src[3 + offset], src[1 + offset]);
float32x4_t tmp2 = vsubq_f32(src[4 + offset], src[2 + offset]);
t[l] = vaddq_f32(vsubq_f32(vmulq_n_f32(src[offset], 4), vmulq_n_f32(src[2 + offset], 5)), src[4 + offset]);
t[6 + l] = vaddq_f32(vmulq_n_f32(vaddq_f32(src[1 + offset], src[2 + offset]), -4),
vaddq_f32(src[3 + offset], src[4 + offset]));
t[12 + l] = vaddq_f32(vmulq_n_f32(vsubq_f32(src[1 + offset], src[2 + offset]), 4),
vsubq_f32(src[4 + offset], src[3 + offset]));
t[18 + l] = vaddq_f32(vmulq_n_f32(tmp1, 2), tmp2);
t[24 + l] = vaddq_f32(vmulq_n_f32(tmp1, -2), tmp2);
t[30 + l] =
vaddq_f32(vsubq_f32(vmulq_n_f32(src[1 + offset], 4), vmulq_n_f32(src[3 + offset], 5)), src[5 + offset]);
}
for (int l = 0; l < 6; ++l) {
int offset = l * 6;
float tmp1 = src[3 + offset] - src[1 + offset];
float tmp2 = src[4 + offset] - src[2 + offset];
t[l] = 4 * src[offset] - 5 * src[2 + offset] + src[4 + offset];
t[6 + l] = -4 * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]);
t[12 + l] = 4 * (src[1 + offset] - src[2 + offset]) + (src[4 + offset] - src[3 + offset]);
t[18 + l] = 2 * tmp1 + tmp2;
t[24 + l] = -2 * tmp1 + tmp2;
t[30 + l] = 4 * src[1 + offset] - 5 * src[3 + offset] + src[5 + offset];
float32x4_t tmp1 = vsubq_f32(t[3 + offset], t[1 + offset]);
float32x4_t tmp2 = vsubq_f32(t[4 + offset], t[2 + offset]);
m[l] = vaddq_f32(vsubq_f32(vmulq_n_f32(t[offset], 4), vmulq_n_f32(t[2 + offset], 5)), t[4 + offset]);
m[6 + l] =
vaddq_f32(vmulq_n_f32(vaddq_f32(t[1 + offset], t[2 + offset]), -4), vaddq_f32(t[3 + offset], t[4 + offset]));
m[12 + l] =
vaddq_f32(vmulq_n_f32(vsubq_f32(t[1 + offset], t[2 + offset]), 4), vsubq_f32(t[4 + offset], t[3 + offset]));
m[18 + l] = vaddq_f32(vmulq_n_f32(tmp1, 2), tmp2);
m[24 + l] = vaddq_f32(vmulq_n_f32(tmp1, -2), tmp2);
m[30 + l] = vaddq_f32(vsubq_f32(vmulq_n_f32(t[1 + offset], 4), vmulq_n_f32(t[3 + offset], 5)), t[5 + offset]);
}
for (int l = 0; l < 6; ++l) {
int offset = l * 6;
float tmp1 = t[3 + offset] - t[1 + offset];
float tmp2 = t[4 + offset] - t[2 + offset];
m[l] = 4 * t[offset] - 5 * t[2 + offset] + t[4 + offset];
m[6 + l] = -4 * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]);
m[12 + l] = 4 * (t[1 + offset] - t[2 + offset]) + (t[4 + offset] - t[3 + offset]);
m[18 + l] = 2 * tmp1 + tmp2;
m[24 + l] = -2 * tmp1 + tmp2;
m[30 + l] = 4 * t[1 + offset] - 5 * t[3 + offset] + t[5 + offset];
for (int i = 0; i < 36; i++) {
vst1q_f32(dst_data + i * dst_step, m[i]);
}
for (int k = 0; k < 36; ++k) {
dst_data[i + k * dst_step] = m[k];
} else {
#endif
float src[36];
float t[36];
float m[36];
for (int i = 0; i < real_c; ++i) {
for (int j = 0; j < 36; ++j) {
src[j] = src_data[i + j * src_step];
}
for (int l = 0; l < 6; ++l) {
int offset = l * 6;
float tmp1 = src[3 + offset] - src[1 + offset];
float tmp2 = src[4 + offset] - src[2 + offset];
t[l] = 4 * src[offset] - 5 * src[2 + offset] + src[4 + offset];
t[6 + l] = -4 * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]);
t[12 + l] = 4 * (src[1 + offset] - src[2 + offset]) + (src[4 + offset] - src[3 + offset]);
t[18 + l] = 2 * tmp1 + tmp2;
t[24 + l] = -2 * tmp1 + tmp2;
t[30 + l] = 4 * src[1 + offset] - 5 * src[3 + offset] + src[5 + offset];
}
for (int l = 0; l < 6; ++l) {
int offset = l * 6;
float tmp1 = t[3 + offset] - t[1 + offset];
float tmp2 = t[4 + offset] - t[2 + offset];
m[l] = 4 * t[offset] - 5 * t[2 + offset] + t[4 + offset];
m[6 + l] = -4 * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]);
m[12 + l] = 4 * (t[1 + offset] - t[2 + offset]) + (t[4 + offset] - t[3 + offset]);
m[18 + l] = 2 * tmp1 + tmp2;
m[24 + l] = -2 * tmp1 + tmp2;
m[30 + l] = 4 * t[1 + offset] - 5 * t[3 + offset] + t[5 + offset];
}
for (int k = 0; k < 36; ++k) {
dst_data[i + k * dst_step] = m[k];
}
}
#ifdef ENABLE_ARM
}
#endif
}
void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step) {
void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
#ifdef ENABLE_ARM
float32x4_t src[64];
float32x4_t t[64];
float32x4_t m[64];
Load64Data;
for (int l = 0; l < 8; ++l) {
int offset = l * 8;
t[l] = vsubq_f32(vaddq_f32(vsubq_f32(vmulq_n_f32(src[offset], 0.5625), vmulq_n_f32(src[2 + offset], 3.0625)),
vmulq_n_f32(src[4 + offset], 3.5)),
src[6 + offset]);
float32x4_t tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 1.125), vmulq_n_f32(src[5 + offset], 0.5));
float32x4_t tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 2.25), vmulq_n_f32(src[4 + offset], 3.25));
t[8 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 1.625)), src[6 + offset]);
t[16 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 1.625)), src[6 + offset]);
tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 0.5625), src[5 + offset]);
tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 0.5625), vmulq_n_f32(src[4 + offset], 2.5));
t[24 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 2.5)), src[6 + offset]);
t[32 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 2.5)), src[6 + offset]);
tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 0.375), vmulq_n_f32(src[5 + offset], 1.5));
tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 0.25), vmulq_n_f32(src[4 + offset], 1.25));
t[40 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 1.875)), src[6 + offset]);
t[48 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 1.875)), src[6 + offset]);
t[56 + l] =
vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src[1 + offset], -0.5625), vmulq_n_f32(src[3 + offset], 3.0625)),
vmulq_n_f32(src[5 + offset], 3.5)),
src[7 + offset]);
}
for (int l = 0; l < 8; ++l) {
int offset = l * 8;
m[l] = vsubq_f32(vaddq_f32(vsubq_f32(vmulq_n_f32(t[offset], 0.5625), vmulq_n_f32(t[2 + offset], 3.0625)),
vmulq_n_f32(t[4 + offset], 3.5)),
t[6 + offset]);
float32x4_t tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 1.125), vmulq_n_f32(t[5 + offset], 0.5));
float32x4_t tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 2.25), vmulq_n_f32(t[4 + offset], 3.25));
m[8 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 1.625)), t[6 + offset]);
m[16 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 1.625)), t[6 + offset]);
tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 0.5625), t[5 + offset]);
tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 0.5625), vmulq_n_f32(t[4 + offset], 2.5));
m[24 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 2.5)), t[6 + offset]);
m[32 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 2.5)), t[6 + offset]);
tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 0.375), vmulq_n_f32(t[5 + offset], 1.5));
tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 0.25), vmulq_n_f32(t[4 + offset], 1.25));
m[40 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 1.875)), t[6 + offset]);
m[48 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 1.875)), t[6 + offset]);
m[56 + l] = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t[1 + offset], -0.5625), vmulq_n_f32(t[3 + offset], 3.0625)),
vmulq_n_f32(t[5 + offset], 3.5)),
t[7 + offset]);
}
for (int i = 0; i < 64; i++) {
vst1q_f32(dst_data + i * dst_step, m[i]);
}
#else
float src[64];
float t[64];
float m[64];
for (int i = 0; i < C4NUM; ++i) {
for (int j = 0; j < 64; ++j) {
src[j] = src_data[i + j * src_step];
if (real_c == 4) {
float32x4_t src[64];
float32x4_t t[64];
float32x4_t m[64];
Load64Data;
for (int l = 0; l < 8; ++l) {
int offset = l * 8;
t[l] = vsubq_f32(vaddq_f32(vsubq_f32(vmulq_n_f32(src[offset], 0.5625), vmulq_n_f32(src[2 + offset], 3.0625)),
vmulq_n_f32(src[4 + offset], 3.5)),
src[6 + offset]);
float32x4_t tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 1.125), vmulq_n_f32(src[5 + offset], 0.5));
float32x4_t tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 2.25), vmulq_n_f32(src[4 + offset], 3.25));
t[8 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 1.625)), src[6 + offset]);
t[16 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 1.625)), src[6 + offset]);
tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 0.5625), src[5 + offset]);
tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 0.5625), vmulq_n_f32(src[4 + offset], 2.5));
t[24 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 2.5)), src[6 + offset]);
t[32 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 2.5)), src[6 + offset]);
tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 0.375), vmulq_n_f32(src[5 + offset], 1.5));
tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 0.25), vmulq_n_f32(src[4 + offset], 1.25));
t[40 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 1.875)), src[6 + offset]);
t[48 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 1.875)), src[6 + offset]);
t[56 + l] =
vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src[1 + offset], -0.5625), vmulq_n_f32(src[3 + offset], 3.0625)),
vmulq_n_f32(src[5 + offset], 3.5)),
src[7 + offset]);
}
for (int l = 0; l < 8; ++l) {
int offset = l * 8;
t[l] = 0.5625f * src[offset] - 3.0625f * src[2 + offset] + 3.5f * src[4 + offset] - src[6 + offset];
float tmp1 = 1.125f * src[1 + offset] + 0.5f * src[5 + offset];
float tmp2 = 2.25f * src[2 + offset] - 3.25f * src[4 + offset];
t[8 + l] = tmp1 + tmp2 - 1.625f * src[3 + offset] + src[6 + offset];
t[16 + l] = tmp2 - tmp1 + 1.625f * src[3 + offset] + src[6 + offset];
tmp1 = 0.5625f * src[1 + offset] + src[5 + offset];
tmp2 = 0.5625f * src[2 + offset] - 2.5f * src[4 + offset];
t[24 + l] = tmp1 + tmp2 - 2.5f * src[3 + offset] + src[6 + offset];
t[32 + l] = tmp2 - tmp1 + 2.5f * src[3 + offset] + src[6 + offset];
tmp1 = 0.375f * src[1 + offset] + 1.5f * src[5 + offset];
tmp2 = 0.25f * src[2 + offset] - 1.25f * src[4 + offset];
t[40 + l] = tmp1 + tmp2 - 1.875f * src[3 + offset] + src[6 + offset];
t[48 + l] = tmp2 - tmp1 + 1.875f * src[3 + offset] + src[6 + offset];
t[56 + l] = -0.5625f * src[1 + offset] + 3.0625f * src[3 + offset] - 3.5f * src[5 + offset] + src[7 + offset];
m[l] = vsubq_f32(vaddq_f32(vsubq_f32(vmulq_n_f32(t[offset], 0.5625), vmulq_n_f32(t[2 + offset], 3.0625)),
vmulq_n_f32(t[4 + offset], 3.5)),
t[6 + offset]);
float32x4_t tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 1.125), vmulq_n_f32(t[5 + offset], 0.5));
float32x4_t tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 2.25), vmulq_n_f32(t[4 + offset], 3.25));
m[8 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 1.625)), t[6 + offset]);
m[16 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 1.625)), t[6 + offset]);
tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 0.5625), t[5 + offset]);
tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 0.5625), vmulq_n_f32(t[4 + offset], 2.5));
m[24 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 2.5)), t[6 + offset]);
m[32 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 2.5)), t[6 + offset]);
tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 0.375), vmulq_n_f32(t[5 + offset], 1.5));
tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 0.25), vmulq_n_f32(t[4 + offset], 1.25));
m[40 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 1.875)), t[6 + offset]);
m[48 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 1.875)), t[6 + offset]);
m[56 + l] =
vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t[1 + offset], -0.5625), vmulq_n_f32(t[3 + offset], 3.0625)),
vmulq_n_f32(t[5 + offset], 3.5)),
t[7 + offset]);
}
for (int l = 0; l < 8; ++l) {
int offset = l * 8;
m[l] = 0.5625f * t[offset] - 3.0625f * t[2 + offset] + 3.5f * t[4 + offset] - t[6 + offset];
float tmp1 = 1.125f * t[1 + offset] + 0.5f * t[5 + offset];
float tmp2 = 2.25f * t[2 + offset] - 3.25f * t[4 + offset];
m[8 + l] = tmp1 + tmp2 - 1.625f * t[3 + offset] + t[6 + offset];
m[16 + l] = tmp2 - tmp1 + 1.625f * t[3 + offset] + t[6 + offset];
tmp1 = 0.5625f * t[1 + offset] + t[5 + offset];
tmp2 = 0.5625f * t[2 + offset] - 2.5f * t[4 + offset];
m[24 + l] = tmp1 + tmp2 - 2.5f * t[3 + offset] + t[6 + offset];
m[32 + l] = tmp2 - tmp1 + 2.5f * t[3 + offset] + t[6 + offset];
tmp1 = 0.375f * t[1 + offset] + 1.5f * t[5 + offset];
tmp2 = 0.25f * t[2 + offset] - 1.25f * t[4 + offset];
m[40 + l] = tmp1 + tmp2 - 1.875f * t[3 + offset] + t[6 + offset];
m[48 + l] = tmp2 - tmp1 + 1.875f * t[3 + offset] + t[6 + offset];
m[56 + l] = -0.5625f * t[1 + offset] + 3.0625f * t[3 + offset] - 3.5f * t[5 + offset] + t[7 + offset];
for (int i = 0; i < 64; i++) {
vst1q_f32(dst_data + i * dst_step, m[i]);
}
for (int k = 0; k < 64; ++k) {
dst_data[i + k * dst_step] = m[k];
} else {
#endif
float src[64];
float t[64];
float m[64];
for (int i = 0; i < real_c; ++i) {
for (int j = 0; j < 64; ++j) {
src[j] = src_data[i + j * src_step];
}
for (int l = 0; l < 8; ++l) {
int offset = l * 8;
t[l] = 0.5625f * src[offset] - 3.0625f * src[2 + offset] + 3.5f * src[4 + offset] - src[6 + offset];
float tmp1 = 1.125f * src[1 + offset] + 0.5f * src[5 + offset];
float tmp2 = 2.25f * src[2 + offset] - 3.25f * src[4 + offset];
t[8 + l] = tmp1 + tmp2 - 1.625f * src[3 + offset] + src[6 + offset];
t[16 + l] = tmp2 - tmp1 + 1.625f * src[3 + offset] + src[6 + offset];
tmp1 = 0.5625f * src[1 + offset] + src[5 + offset];
tmp2 = 0.5625f * src[2 + offset] - 2.5f * src[4 + offset];
t[24 + l] = tmp1 + tmp2 - 2.5f * src[3 + offset] + src[6 + offset];
t[32 + l] = tmp2 - tmp1 + 2.5f * src[3 + offset] + src[6 + offset];
tmp1 = 0.375f * src[1 + offset] + 1.5f * src[5 + offset];
tmp2 = 0.25f * src[2 + offset] - 1.25f * src[4 + offset];
t[40 + l] = tmp1 + tmp2 - 1.875f * src[3 + offset] + src[6 + offset];
t[48 + l] = tmp2 - tmp1 + 1.875f * src[3 + offset] + src[6 + offset];
t[56 + l] = -0.5625f * src[1 + offset] + 3.0625f * src[3 + offset] - 3.5f * src[5 + offset] + src[7 + offset];
}
for (int l = 0; l < 8; ++l) {
int offset = l * 8;
m[l] = 0.5625f * t[offset] - 3.0625f * t[2 + offset] + 3.5f * t[4 + offset] - t[6 + offset];
float tmp1 = 1.125f * t[1 + offset] + 0.5f * t[5 + offset];
float tmp2 = 2.25f * t[2 + offset] - 3.25f * t[4 + offset];
m[8 + l] = tmp1 + tmp2 - 1.625f * t[3 + offset] + t[6 + offset];
m[16 + l] = tmp2 - tmp1 + 1.625f * t[3 + offset] + t[6 + offset];
tmp1 = 0.5625f * t[1 + offset] + t[5 + offset];
tmp2 = 0.5625f * t[2 + offset] - 2.5f * t[4 + offset];
m[24 + l] = tmp1 + tmp2 - 2.5f * t[3 + offset] + t[6 + offset];
m[32 + l] = tmp2 - tmp1 + 2.5f * t[3 + offset] + t[6 + offset];
tmp1 = 0.375f * t[1 + offset] + 1.5f * t[5 + offset];
tmp2 = 0.25f * t[2 + offset] - 1.25f * t[4 + offset];
m[40 + l] = tmp1 + tmp2 - 1.875f * t[3 + offset] + t[6 + offset];
m[48 + l] = tmp2 - tmp1 + 1.875f * t[3 + offset] + t[6 + offset];
m[56 + l] = -0.5625f * t[1 + offset] + 3.0625f * t[3 + offset] - 3.5f * t[5 + offset] + t[7 + offset];
}
for (int k = 0; k < 64; ++k) {
dst_data[i + k * dst_step] = m[k];
}
}
#ifdef ENABLE_ARM
}
#endif
}

View File

@ -28,7 +28,7 @@
#ifdef __cplusplus
extern "C" {
#endif
typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step);
typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
@ -163,11 +163,11 @@ void GeneralOutputTransformUnit(const float *src_data, float *dst_data, const fl
InputTransFunc GetInputTransFunc(int input_unit);
void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step);
void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step);
void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step);
void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type);

View File

@ -39,21 +39,18 @@ int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_da
// original weight format : ohwi
auto channel_in = conv_param_->input_channel_;
auto channel_out = conv_param_->output_channel_;
int ic4 = UP_DIV(channel_in, C4NUM);
int oc_block_num = UP_DIV(channel_out, oc_block);
int c4_channel = ic4 * C4NUM;
int block_stride = c4_channel * oc_block;
int block_stride = channel_in * oc_block;
int block_num_stride = block_stride * oc_block_num;
// trans_filter = G*g*GT (g represents weight_data)
// separate into two steps ===> tmp = (g * GT)T ===> trans = (tmp * GT)T use same function:MatrixMultiplyWinograd
auto tmp_data = reinterpret_cast<float *>(malloc(c4_channel * input_unit_ * kernel_unit_ * sizeof(float)));
auto tmp_data = reinterpret_cast<float *>(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float)));
if (tmp_data == nullptr) {
MS_LOG(ERROR) << "malloc tmp_data failed.";
return RET_MEMORY_FAILED;
}
memset(tmp_data, 0, c4_channel * input_unit_ * kernel_unit_ * sizeof(float));
auto trans_out_data = reinterpret_cast<float *>(malloc(c4_channel * input_unit_ * input_unit_ * sizeof(float)));
auto trans_out_data = reinterpret_cast<float *>(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float)));
if (trans_out_data == nullptr) {
free(tmp_data);
MS_LOG(ERROR) << "malloc trans_out_data failed.";
@ -61,14 +58,14 @@ int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_da
}
#ifndef ENABLE_ARM64
auto tmp_data1 = reinterpret_cast<float *>(malloc(c4_channel * input_unit_ * kernel_unit_ * sizeof(float)));
auto tmp_data1 = reinterpret_cast<float *>(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float)));
if (tmp_data1 == nullptr) {
free(tmp_data);
free(trans_out_data);
MS_LOG(ERROR) << "malloc tmp_data1 failed.";
return RET_MEMORY_FAILED;
}
auto trans_out_data1 = reinterpret_cast<float *>(malloc(c4_channel * input_unit_ * input_unit_ * sizeof(float)));
auto trans_out_data1 = reinterpret_cast<float *>(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float)));
if (trans_out_data1 == nullptr) {
free(tmp_data);
free(tmp_data1);
@ -87,30 +84,30 @@ int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_da
#ifndef ENABLE_ARM64
// tmp_data = g * GT
MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit_, kernel_unit_,
input_unit_, channel_in, c4_channel * 4);
input_unit_, channel_in, channel_in * 4);
// tmp_data1 = (tmp_data)T
PackHWCToWHC(tmp_data, tmp_data1, kernel_unit_, input_unit_, c4_channel);
PackHWCToWHC(tmp_data, tmp_data1, kernel_unit_, input_unit_, channel_in);
// trans_out_data1 = tmp * GT
MatrixMultiplyWinograd(tmp_data1, matrix_gt, trans_out_data1, input_unit_, kernel_unit_, input_unit_, c4_channel,
c4_channel * 4);
MatrixMultiplyWinograd(tmp_data1, matrix_gt, trans_out_data1, input_unit_, kernel_unit_, input_unit_, channel_in,
channel_in * 4);
// trans_out_data = (trans_out_data1)T
PackHWCToWHC(trans_out_data1, trans_out_data, input_unit_, input_unit_, c4_channel);
PackHWCToWHC(trans_out_data1, trans_out_data, input_unit_, input_unit_, channel_in);
#else
// tmp = (g * GT)T
MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit_, kernel_unit_,
input_unit_, channel_in, c4_channel * 4);
input_unit_, channel_in, channel_in * 4);
// trans = (tmp * GT)T
MatrixMultiplyWinograd(tmp_data, matrix_gt, trans_out_data, input_unit_, kernel_unit_, input_unit_, c4_channel,
c4_channel * 4);
MatrixMultiplyWinograd(tmp_data, matrix_gt, trans_out_data, input_unit_, kernel_unit_, input_unit_, channel_in,
channel_in * 4);
#endif
int in_offset = 0;
for (int j = 0; j < input_unit_; ++j) {
for (int k = 0; k < input_unit_; ++k) {
for (int c = 0; c < c4_channel; ++c) {
for (int c = 0; c < channel_in; ++c) {
*(trans_weight_ + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c];
}
in_offset += c4_channel;
in_offset += channel_in;
output_oz_offset += block_num_stride;
}
}
@ -128,7 +125,6 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
auto filter_tensor = in_tensors_.at(kWeightIndex);
int in_channel = filter_tensor->Channel();
int out_channel = filter_tensor->Batch();
int ic4 = UP_DIV(in_channel, C4NUM);
conv_param_->input_channel_ = in_channel;
conv_param_->output_channel_ = out_channel;
@ -137,7 +133,7 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
int oc_block_num = UP_DIV(out_channel, C8NUM);
// set data
auto trans_matrix_data_size = input_unit_ * input_unit_ * ic4 * C4NUM * oc_block_num * oc_block * sizeof(float);
auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float);
trans_weight_ = reinterpret_cast<float *>(malloc(trans_matrix_data_size));
if (trans_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc matrix_buffer failed.";
@ -188,7 +184,6 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
int channel_out = conv_param_->output_channel_;
int oc8 = UP_DIV(channel_out, C8NUM);
int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
#ifdef ENABLE_ARM32
int tile_num = 4;
#else
@ -196,7 +191,8 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
#endif
MS_ASSERT(ctx_->allocator != nullptr);
size_t tile_buffer_size = thread_count_ * tile_num * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float);
size_t tile_buffer_size =
thread_count_ * tile_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float);
trans_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));
if (trans_input_ == nullptr) {
MS_LOG(ERROR) << "malloc trans_input_ failed.";
@ -217,8 +213,8 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
return RET_MEMORY_FAILED;
}
col_buffer_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(thread_count_ * tile_num * ic4 * C4NUM * sizeof(float)));
col_buffer_ = reinterpret_cast<float *>(
ctx_->allocator->Malloc(thread_count_ * tile_num * conv_param_->input_channel_ * sizeof(float)));
if (col_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc col_buffer_ failed.";
return RET_ERROR;