forked from mindspore-Ecosystem/mindspore

optimize winograd input transform func

This commit is contained in:
parent 8b4cdc1523
commit 2d00b74de2
@@ -77,7 +77,6 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
   int input_unit = conv_param->input_unit_;
   int in_batch = conv_param->input_batch_;
   int in_channel = conv_param->input_channel_;
-  int ic4 = UP_DIV(in_channel, C4NUM);
   int out_unit = conv_param->output_unit_;
   int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
   int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
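Note: this deletion is the pattern for the whole commit. Strides and buffer sizes that were computed from the channel count rounded up to a multiple of four (ic4 * C4NUM) now use the raw in_channel, so the transformed-input buffers are packed instead of padded. A minimal sketch of the arithmetic, assuming C4NUM == 4 (as elsewhere in MindSpore Lite) and a hypothetical 3-channel input:

    #include <stdio.h>
    #define UP_DIV(x, y) (((x) + (y) - 1) / (y)) /* ceiling division */
    #define C4NUM 4

    int main(void) {
      int in_channel = 3;                  /* hypothetical channel count */
      int ic4 = UP_DIV(in_channel, C4NUM); /* 1 */
      int input_unit_square = 4 * 4;       /* 4x4 input unit */
      int tile_num = 12;
      /* old per-task stride: 12 * 16 * 1 * 4 = 768 floats (one zero lane per pixel) */
      printf("old: %d\n", tile_num * input_unit_square * ic4 * C4NUM);
      /* new per-task stride: 12 * 16 * 3 = 576 floats, fully packed */
      printf("new: %d\n", tile_num * input_unit_square * in_channel);
      return 0;
    }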
@@ -96,10 +95,10 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
   float *gemm_out = buffer_list[1];
   float *tmp_data = buffer_list[2];
   float *col_buffer = buffer_list[3];
-  int trans_input_offset = tile_num * input_unit_square * ic4 * C4NUM;
+  int trans_input_offset = tile_num * input_unit_square * in_channel;
   int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
   int tmp_data_offset = input_unit_square * C4NUM;
-  int col_buffer_offset = tile_num * ic4 * C4NUM;
+  int col_buffer_offset = tile_num * in_channel;
   // step 1 : filter transform (pre-processed offline)
   // step 2 : input transform (online)
   for (int b = 0; b < in_batch; b++) {
@@ -107,7 +106,7 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
     int out_batch_offset = b * out_channel * conv_param->output_w_ * conv_param->output_h_;
     for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
       int out_tile_index = thread_id * tile_num;
-      int cal_num = output_count - thread_id * tile_num;
+      int cal_num = output_count - out_tile_index;
       cal_num = cal_num > tile_num ? tile_num : cal_num;
       WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
                              tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
@@ -118,11 +117,11 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
       float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
       for (int i = 0; i < input_unit_square; ++i) {
 #ifdef ENABLE_ARM32
-        RowMajor2Col4Major(src_ptr + i * C4NUM * ic4 * C4NUM, tmp_col_ptr, C4NUM, ic4 * C4NUM);
+        RowMajor2Col4Major(src_ptr + i * C4NUM * in_channel, tmp_col_ptr, C4NUM, in_channel);
 #else
-        RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
+        RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
 #endif
-        MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM,
+        MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
                   cal_num, oc8 * C8NUM, input_unit_square, 2);
       }
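For reference, RowMajor2Col12Major repacks the 12 x in_channel tile into the panel layout the GEMM left operand expects. The routine itself is not part of this diff; the sketch below is a plausible reading of its name (column-major panels of 12 rows), not the library's actual implementation:

    /* Hypothetical sketch: pack a row-major [row x col] matrix into 12-row
     * column-major panels. Assumes row is a multiple of 12, as in the call
     * above where row == C12NUM. Not the actual MindSpore Lite routine. */
    void RowMajor2Col12MajorSketch(const float *src, float *dst, int row, int col) {
      const int kPanel = 12;
      for (int r = 0; r < row; r++) {
        int panel = r / kPanel; /* which 12-row panel this row lands in */
        int lane = r % kPanel;  /* slot inside the panel */
        for (int c = 0; c < col; c++) {
          dst[(panel * col + c) * kPanel + lane] = src[r * col + c];
        }
      }
    }

With the padded layout gone, the deep dimension of each GEMM is exactly in_channel rather than ic4 * C4NUM, so no zero lanes are multiplied.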
@@ -630,26 +630,6 @@ void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int c
   }
 }
 
-void PackNCHWToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
-  int nhwc4_batch_offset = 0;
-  int c4 = UP_DIV(channel, C4NUM);
-  int nhwc4_batch_unit_offset = c4 * C4NUM * plane;
-
-  for (int b = 0; b < batch; b++) {
-    int batch_offset = b * channel * plane;
-    for (int c = 0; c < channel; c++) {
-      int src_c_offset = batch_offset + c * plane;
-      int dst_c_offset = nhwc4_batch_offset + c;
-      for (int i = 0; i < plane; i++) {
-        int src_plane_offset = src_c_offset + i;
-        int dst_plane_offset = dst_c_offset + i * c4 * C4NUM;
-        ((float *)dst)[dst_plane_offset] = ((float *)src)[src_plane_offset];
-      }
-    }
-    nhwc4_batch_offset += nhwc4_batch_unit_offset;
-  }
-}
-
 void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
   int c4 = UP_DIV(channel, C4NUM);
   for (int b = 0; b < batch; b++) {
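The deleted helper shows what NHWC4 means: each pixel stores c4 * C4NUM channel slots, with the tail zero-padded. A tiny worked example (values invented) with plane = 2 pixels and channel = 3, following the indexing of the removed function:

    #include <stdio.h>
    #define C4NUM 4
    #define UP_DIV(x, y) (((x) + (y) - 1) / (y))

    int main(void) {
      float nchw[3 * 2] = {1, 2, 3, 4, 5, 6}; /* ch0 {1,2}, ch1 {3,4}, ch2 {5,6} */
      int plane = 2, channel = 3;
      int c4 = UP_DIV(channel, C4NUM);
      float nhwc4[2 * 4] = {0}; /* plane * c4 * C4NUM; the padding slot stays 0 */
      for (int c = 0; c < channel; c++) {
        for (int i = 0; i < plane; i++) {
          nhwc4[i * c4 * C4NUM + c] = nchw[c * plane + i];
        }
      }
      for (int i = 0; i < 8; i++) printf("%g ", nhwc4[i]); /* 1 3 5 0 2 4 6 0 */
      return 0;
    }

The Winograd path now consumes plain NHWC (exactly `channel` values per pixel), which is why this conversion and the other padded-layout helpers below lose their callers.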
@@ -700,105 +680,6 @@ void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int
   }
 }
 
-void PackNC4HW4ToNHWCReluFp32(const void *src, void *dst, int batch, int plane, int channel) {
-  int c4 = UP_DIV(channel, C4NUM);
-  for (int b = 0; b < batch; b++) {
-    int src_offset = b * plane * c4 * C4NUM;
-    int dst_offset = b * plane * channel;
-    for (int k = 0; k < plane; k++) {
-      int src_kernel_offset = src_offset + k * C4NUM;
-      int dst_kernel_offset = dst_offset + k * channel;
-      for (int c = 0; c < c4 - 1; c++) {
-        int src_c_offset = src_kernel_offset + c * plane * C4NUM;
-        int dst_c_offset = dst_kernel_offset + c * C4NUM;
-#ifdef ENABLE_NEON
-        float32x4_t input_ptr = vld1q_f32((float *)src + src_c_offset);
-        float32x4_t zero = vdupq_n_f32(0);
-        input_ptr = vmaxq_f32(zero, input_ptr);
-        vst1q_f32((float *)dst + dst_c_offset, input_ptr);
-#else
-        for (int i = 0; i < C4NUM; ++i) {
-          float input_data = ((float *)src + src_c_offset)[i];
-          input_data = input_data < 0 ? 0 : input_data;
-          ((float *)dst + dst_c_offset)[i] = input_data;
-        }
-#endif
-      }
-      // res part
-      int res_c = channel - (c4 - 1) * C4NUM;
-      for (int i = 0; i < res_c; i++) {
-        int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
-        int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
-        float input_data = ((float *)src + src_res_c_offset)[0];
-        input_data = input_data < 0 ? 0 : input_data;
-        ((float *)dst + dst_res_c_offset)[0] = input_data;
-      }
-    }
-  }
-}
-
-void PackNC4HW4ToNHWCRelu6Fp32(const void *src, void *dst, int batch, int plane, int channel) {
-  int c4 = UP_DIV(channel, C4NUM);
-  for (int b = 0; b < batch; b++) {
-    int src_offset = b * plane * c4 * C4NUM;
-    int dst_offset = b * plane * channel;
-    for (int k = 0; k < plane; k++) {
-      int src_kernel_offset = src_offset + k * C4NUM;
-      int dst_kernel_offset = dst_offset + k * channel;
-      for (int c = 0; c < c4 - 1; c++) {
-        int src_c_offset = src_kernel_offset + c * plane * C4NUM;
-        int dst_c_offset = dst_kernel_offset + c * C4NUM;
-#ifdef ENABLE_NEON
-        float32x4_t input_ptr = vld1q_f32((float *)src + src_c_offset);
-        float32x4_t zero = vdupq_n_f32(0);
-        float32x4_t six = vdupq_n_f32(6);
-        input_ptr = vmaxq_f32(zero, input_ptr);
-        input_ptr = vminq_f32(six, input_ptr);
-        vst1q_f32((float *)dst + dst_c_offset, input_ptr);
-#else
-        for (int i = 0; i < C4NUM; ++i) {
-          float input_data = ((float *)src + src_c_offset)[i];
-          input_data = input_data < 0 ? 0 : input_data;
-          input_data = input_data > 6 ? 6 : input_data;
-          ((float *)dst + dst_c_offset)[i] = input_data;
-        }
-#endif
-      }
-      // res part
-      int res_c = channel - (c4 - 1) * C4NUM;
-      for (int i = 0; i < res_c; i++) {
-        int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
-        int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
-        float input_data = ((float *)src + src_res_c_offset)[0];
-        input_data = input_data < 0 ? 0 : input_data;
-        input_data = input_data > 6 ? 6 : input_data;
-        ((float *)dst + dst_res_c_offset)[0] = input_data;
-      }
-    }
-  }
-}
-
-void PackNC4HW4ToNHWCPreluFp32(const void *src, void *dst, const void *slope, int batch, int plane, int channel) {}
-
-void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) {
-  int c4 = UP_DIV(channel, C4NUM);
-  for (int b = 0; b < batch; b++) {
-    int src_offset = b * plane * c4 * C4NUM;
-    int dst_offset = b * plane * channel;
-    for (int c = 0; c < channel; c++) {
-      int c4_block_num = c / C4NUM;
-      int c4_block_res = c % C4NUM;
-      int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
-      int dst_c_offset = dst_offset + c * plane;
-      for (int k = 0; k < plane; k++) {
-        int src_kernel_offset = src_c_offset + k * C4NUM;
-        int dst_kernel_offset = dst_c_offset + k;
-        ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0];
-      }
-    }
-  }
-}
-
 void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
   for (int n = 0; n < batch; n++) {
     for (int hw = 0; hw < plane; hw++) {
@@ -896,45 +777,6 @@ void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int c
   }
 }
 
-void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) {
-  int nhwc4_batch_offset = 0;
-  int c4 = UP_DIV(channel, C4NUM);
-  int nhwc4_batch_unit_offset = c4 * C4NUM * plane;
-
-  for (int b = 0; b < batch; b++) {
-    int batch_offset = b * channel * plane;
-    for (int c = 0; c < channel; c++) {
-      int src_c_offset = batch_offset + c * plane;
-      int dst_c_offset = nhwc4_batch_offset + c;
-      for (int i = 0; i < plane; i++) {
-        int src_plane_offset = src_c_offset + i;
-        int dst_plane_offset = dst_c_offset + i * c4 * C4NUM;
-        ((uint8_t *)dst)[dst_plane_offset] = ((uint8_t *)src)[src_plane_offset];
-      }
-    }
-    nhwc4_batch_offset += nhwc4_batch_unit_offset;
-  }
-}
-
-void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) {
-  int c4 = UP_DIV(channel, C4NUM);
-  for (int b = 0; b < batch; b++) {
-    int src_offset = b * plane * c4 * C4NUM;
-    int dst_offset = b * plane * channel;
-    for (int c = 0; c < channel; c++) {
-      int c4_block_num = c / C4NUM;
-      int c4_block_res = c % C4NUM;
-      int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
-      int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res;
-      for (int k = 0; k < plane; k++) {
-        int src_kernel_offset = src_c_offset + k * C4NUM;
-        int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM;
-        ((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0];
-      }
-    }
-  }
-}
-
 void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
   int c4 = UP_DIV(channel, C4NUM);
   for (int b = 0; b < batch; b++) {
@@ -962,25 +804,6 @@ void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int
   }
 }
 
-void PackNC4HW4ToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel) {
-  int c4 = UP_DIV(channel, C4NUM);
-  for (int b = 0; b < batch; b++) {
-    int src_offset = b * plane * c4 * C4NUM;
-    int dst_offset = b * plane * channel;
-    for (int c = 0; c < channel; c++) {
-      int c4_block_num = c / C4NUM;
-      int c4_block_res = c % C4NUM;
-      int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
-      int dst_c_offset = dst_offset + c * plane;
-      for (int k = 0; k < plane; k++) {
-        int src_kernel_offset = src_c_offset + k * C4NUM;
-        int dst_kernel_offset = dst_c_offset + k;
-        ((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0];
-      }
-    }
-  }
-}
-
 void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel) {
   for (int n = 0; n < batch; n++) {
     for (int hw = 0; hw < plane; hw++) {
@@ -996,25 +819,6 @@ void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int
   return;
 }
 
-void PackNHWCToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel) {
-  int c8 = UP_DIV(channel, C8NUM);
-  for (int b = 0; b < batch; b++) {
-    int src_oc_offset = b * plane * channel;
-    int dst_oc_offset = b * plane * c8 * C8NUM;
-    for (int k = 0; k < plane; k++) {
-      int src_kernel_offset = src_oc_offset + k * channel;
-      int dst_kernel_offset = dst_oc_offset + k * C8NUM;
-      for (int i = 0; i < channel; i++) {
-        int c8_block_num = i / C8NUM;
-        int c8_block_rem = i % C8NUM;
-        int src_ic_offset = src_kernel_offset + i;
-        int dst_ic_offset = dst_kernel_offset + c8_block_num * plane * C8NUM + c8_block_rem;
-        ((int8_t *)dst + dst_ic_offset)[0] = ((int8_t *)src + src_ic_offset)[0];
-      }
-    }
-  }
-}
-
 void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
   for (int n = 0; n < batch; n++) {
     for (int c = 0; c < channel; c++) {
@@ -1231,27 +1035,6 @@ void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int ch
   return PackNHWCToNCHWFp32(src, dst, batch, channel, plane);
 }
 
-void MatrixPackUnit(const float *src, float *dst, size_t row, size_t col, size_t src_stride, size_t dst_stride) {
-  size_t copy_size = row * C4NUM * sizeof(float);
-  for (int c = 0; c < col; c++) {
-    memcpy(dst + c * dst_stride, src + c * src_stride, copy_size);
-  }
-}
-
-void MatrixPack(const float *src, float *dst, int row, int ic4, int stride) {
-  int row4mod = row % 4;
-  int row4div = row / 4;
-
-  for (int i = 0; i < row4div; i++) {
-    MatrixPackUnit(src + i * 4 * 4, dst + i * 4 * ic4 * 4, 4, ic4, stride, 16);
-  }
-
-  if (row4mod > 0) {
-    MatrixPackUnit(src + row4div * 4 * 4, dst + row4div * 4 * ic4 * 4, row4mod, ic4, stride, row4mod * 4);
-  }
-  return;
-}
-
 void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param) {
   int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
   int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
@@ -46,11 +46,6 @@ void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParam
 
 void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param);
 
-void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
-                         size_t plane_size, ConvParameter *conv_param);
-
-void MatrixPack(const float *src, float *dst, int row, int ic4, int stride);
-
 void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param);
 
 void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
@@ -75,20 +70,10 @@ void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int ch
 
 void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
 
-void PackNCHWToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
-
 void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
 
 void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
 
-void PackNC4HW4ToNHWCReluFp32(const void *src, void *dst, int batch, int plane, int channel);
-
-void PackNC4HW4ToNHWCRelu6Fp32(const void *src, void *dst, int batch, int plane, int channel);
-
-void PackNC4HW4ToNHWCPreluFp32(const void *src, void *dst, const void *slope, int batch, int plane, int channel);
-
-void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
-
 void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);
 
 void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
@@ -99,18 +84,10 @@ void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int c
 
 void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
 
-void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
-
-void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
-
 void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
 
-void PackNC4HW4ToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel);
-
 void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel);
 
-void PackNHWCToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel);
-
 void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
 
 void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param);
@@ -42,7 +42,7 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
     int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s);
 
     int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s);
-    int dst_plane_offset = c * C4NUM * ic4;
+    int dst_plane_offset = c * in_channel;
     for (int ic = 0; ic < ic4; ic++) {
       // clear tmp buffer
       memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float));
@@ -91,9 +91,9 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
       const int tile_num = 12;
 #endif
       int dst_ic4_offset = dst_plane_offset + ic * C4NUM;
-      size_t dst_step = tile_num * ic4 * C4NUM;
+      size_t dst_step = tile_num * in_channel;
       float *trans_input_ptr = trans_input + dst_ic4_offset;
-      func(tmp_data, trans_input_ptr, C4NUM, dst_step);
+      func(tmp_data, trans_input_ptr, C4NUM, dst_step, real_c);
     }
     out_tile_index++;
   }  // cal_tile_num loop
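The extra real_c argument tells the per-unit transform how many of the C4NUM channel lanes in tmp_data are valid, so the tail channel block no longer needs padded output. The computation of real_c is not visible in this excerpt; presumably it is the size of the current four-channel block, along these lines (reusing the UP_DIV/C4NUM macros sketched above):

    /* Assumed shape of the caller's channel loop; real_c's actual definition
     * is not shown in this hunk. */
    void channel_blocks_sketch(int in_channel) {
      int ic4 = UP_DIV(in_channel, C4NUM);
      for (int ic = 0; ic < ic4; ic++) {
        int real_c = in_channel - ic * C4NUM;
        if (real_c > C4NUM) real_c = C4NUM; /* e.g. in_channel = 6 -> 4, then 2 */
        /* func(tmp_data, trans_input_ptr, C4NUM, dst_step, real_c); */
      }
    }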
@@ -171,227 +171,241 @@ void GeneralOutputTransformUnit(const float *src_data, float *dst_data, const fl
 
 InputTransFunc GetInputTransFunc(int input_unit) { return InputTransFuncList[input_unit]; }
 
-void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step) {
+void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
 #ifdef ENABLE_ARM
+  if (real_c == 4) {
   float32x4_t src[16];
   float32x4_t t[16];
   float32x4_t m[16];
   Load16Data;
   for (int l = 0; l < 4; ++l) {
     int offset = l * 4;
     t[l] = vsubq_f32(src[offset], src[2 + offset]);
     t[4 + l] = vaddq_f32(src[1 + offset], src[2 + offset]);
     t[8 + l] = vsubq_f32(src[2 + offset], src[1 + offset]);
     t[12 + l] = vsubq_f32(src[3 + offset], src[1 + offset]);
   }
   for (int l = 0; l < 4; ++l) {
     int offset = l * 4;
     m[l] = vsubq_f32(t[offset], t[2 + offset]);
     m[4 + l] = vaddq_f32(t[1 + offset], t[2 + offset]);
     m[8 + l] = vsubq_f32(t[2 + offset], t[1 + offset]);
     m[12 + l] = vsubq_f32(t[3 + offset], t[1 + offset]);
   }
   for (int i = 0; i < 16; i++) {
     vst1q_f32(dst_data + i * dst_step, m[i]);
   }
-#else
+  } else {
+#endif
   float src[16];
   float t[16];
   float m[16];
-  for (int i = 0; i < C4NUM; ++i) {
+  for (int i = 0; i < real_c; ++i) {
     for (int j = 0; j < 16; ++j) {
       src[j] = src_data[i + j * src_step];
     }
     for (int l = 0; l < 4; ++l) {
       int offset = l * 4;
       t[l] = src[offset] - src[2 + offset];
       t[4 + l] = src[1 + offset] + src[2 + offset];
       t[8 + l] = src[2 + offset] - src[1 + offset];
       t[12 + l] = src[3 + offset] - src[1 + offset];
     }
     for (int l = 0; l < 4; ++l) {
       int offset = l * 4;
       m[l] = t[offset] - t[2 + offset];
       m[4 + l] = t[1 + offset] + t[2 + offset];
       m[8 + l] = t[2 + offset] - t[1 + offset];
       m[12 + l] = t[3 + offset] - t[1 + offset];
     }
     for (int k = 0; k < 16; ++k) {
       dst_data[i + k * dst_step] = m[k];
     }
   }
+#ifdef ENABLE_ARM
+  }
 #endif
 }
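Note: the scalar branch makes the algebra explicit. The two loop nests apply the Winograd input-transform matrix to both sides of the 4x4 tile d, computing B^T d B for F(2x2, 3x3); reading the coefficients off the code,

    B^T = |  1   0  -1   0 |
          |  0   1   1   0 |
          |  0  -1   1   0 |
          |  0  -1   0   1 |

The NEON branch computes the same thing with one float32x4_t lane per channel, which is why it now requires all four lanes to be valid (real_c == 4) and falls back to the scalar path for a partial tail block.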
 
-void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step) {
+void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
 #ifdef ENABLE_ARM
+  if (real_c == 4) {
   float32x4_t src[36];
   float32x4_t t[36];
   float32x4_t m[36];
   Load36Data;
   for (int l = 0; l < 6; ++l) {
     int offset = l * 6;
     float32x4_t tmp1 = vsubq_f32(src[3 + offset], src[1 + offset]);
     float32x4_t tmp2 = vsubq_f32(src[4 + offset], src[2 + offset]);
     t[l] = vaddq_f32(vsubq_f32(vmulq_n_f32(src[offset], 4), vmulq_n_f32(src[2 + offset], 5)), src[4 + offset]);
     t[6 + l] = vaddq_f32(vmulq_n_f32(vaddq_f32(src[1 + offset], src[2 + offset]), -4),
                          vaddq_f32(src[3 + offset], src[4 + offset]));
     t[12 + l] = vaddq_f32(vmulq_n_f32(vsubq_f32(src[1 + offset], src[2 + offset]), 4),
                           vsubq_f32(src[4 + offset], src[3 + offset]));
     t[18 + l] = vaddq_f32(vmulq_n_f32(tmp1, 2), tmp2);
     t[24 + l] = vaddq_f32(vmulq_n_f32(tmp1, -2), tmp2);
     t[30 + l] = vaddq_f32(vsubq_f32(vmulq_n_f32(src[1 + offset], 4), vmulq_n_f32(src[3 + offset], 5)), src[5 + offset]);
   }
   for (int l = 0; l < 6; ++l) {
     int offset = l * 6;
     float32x4_t tmp1 = vsubq_f32(t[3 + offset], t[1 + offset]);
     float32x4_t tmp2 = vsubq_f32(t[4 + offset], t[2 + offset]);
     m[l] = vaddq_f32(vsubq_f32(vmulq_n_f32(t[offset], 4), vmulq_n_f32(t[2 + offset], 5)), t[4 + offset]);
     m[6 + l] =
       vaddq_f32(vmulq_n_f32(vaddq_f32(t[1 + offset], t[2 + offset]), -4), vaddq_f32(t[3 + offset], t[4 + offset]));
     m[12 + l] =
       vaddq_f32(vmulq_n_f32(vsubq_f32(t[1 + offset], t[2 + offset]), 4), vsubq_f32(t[4 + offset], t[3 + offset]));
     m[18 + l] = vaddq_f32(vmulq_n_f32(tmp1, 2), tmp2);
     m[24 + l] = vaddq_f32(vmulq_n_f32(tmp1, -2), tmp2);
     m[30 + l] = vaddq_f32(vsubq_f32(vmulq_n_f32(t[1 + offset], 4), vmulq_n_f32(t[3 + offset], 5)), t[5 + offset]);
   }
   for (int i = 0; i < 36; i++) {
     vst1q_f32(dst_data + i * dst_step, m[i]);
   }
-#else
+  } else {
+#endif
   float src[36];
   float t[36];
   float m[36];
-  for (int i = 0; i < C4NUM; ++i) {
+  for (int i = 0; i < real_c; ++i) {
     for (int j = 0; j < 36; ++j) {
       src[j] = src_data[i + j * src_step];
     }
     for (int l = 0; l < 6; ++l) {
       int offset = l * 6;
       float tmp1 = src[3 + offset] - src[1 + offset];
       float tmp2 = src[4 + offset] - src[2 + offset];
       t[l] = 4 * src[offset] - 5 * src[2 + offset] + src[4 + offset];
       t[6 + l] = -4 * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]);
       t[12 + l] = 4 * (src[1 + offset] - src[2 + offset]) + (src[4 + offset] - src[3 + offset]);
       t[18 + l] = 2 * tmp1 + tmp2;
       t[24 + l] = -2 * tmp1 + tmp2;
       t[30 + l] = 4 * src[1 + offset] - 5 * src[3 + offset] + src[5 + offset];
     }
     for (int l = 0; l < 6; ++l) {
       int offset = l * 6;
       float tmp1 = t[3 + offset] - t[1 + offset];
       float tmp2 = t[4 + offset] - t[2 + offset];
       m[l] = 4 * t[offset] - 5 * t[2 + offset] + t[4 + offset];
       m[6 + l] = -4 * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]);
       m[12 + l] = 4 * (t[1 + offset] - t[2 + offset]) + (t[4 + offset] - t[3 + offset]);
       m[18 + l] = 2 * tmp1 + tmp2;
       m[24 + l] = -2 * tmp1 + tmp2;
       m[30 + l] = 4 * t[1 + offset] - 5 * t[3 + offset] + t[5 + offset];
     }
     for (int k = 0; k < 36; ++k) {
       dst_data[i + k * dst_step] = m[k];
     }
   }
+#ifdef ENABLE_ARM
+  }
 #endif
 }
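Note: the 6x6 unit follows the same restructuring and implements the input transform of F(4x4, 3x3); the coefficients 4, -5, +/-2 in the code are the standard transform for interpolation points {0, +/-1, +/-2}:

    B^T = |  4   0  -5   0   1   0 |
          |  0  -4  -4   1   1   0 |
          |  0   4  -4  -1   1   0 |
          |  0  -2  -1   2   1   0 |
          |  0   2  -1  -2   1   0 |
          |  0   4   0  -5   0   1 |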
 
-void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step) {
+void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
 #ifdef ENABLE_ARM
+  if (real_c == 4) {
   float32x4_t src[64];
   float32x4_t t[64];
   float32x4_t m[64];
   Load64Data;
   for (int l = 0; l < 8; ++l) {
     int offset = l * 8;
     t[l] = vsubq_f32(vaddq_f32(vsubq_f32(vmulq_n_f32(src[offset], 0.5625), vmulq_n_f32(src[2 + offset], 3.0625)),
                                vmulq_n_f32(src[4 + offset], 3.5)),
                      src[6 + offset]);
     float32x4_t tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 1.125), vmulq_n_f32(src[5 + offset], 0.5));
     float32x4_t tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 2.25), vmulq_n_f32(src[4 + offset], 3.25));
     t[8 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 1.625)), src[6 + offset]);
     t[16 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 1.625)), src[6 + offset]);
     tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 0.5625), src[5 + offset]);
     tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 0.5625), vmulq_n_f32(src[4 + offset], 2.5));
     t[24 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 2.5)), src[6 + offset]);
     t[32 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 2.5)), src[6 + offset]);
     tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 0.375), vmulq_n_f32(src[5 + offset], 1.5));
     tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 0.25), vmulq_n_f32(src[4 + offset], 1.25));
     t[40 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 1.875)), src[6 + offset]);
     t[48 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 1.875)), src[6 + offset]);
     t[56 + l] =
       vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src[1 + offset], -0.5625), vmulq_n_f32(src[3 + offset], 3.0625)),
                           vmulq_n_f32(src[5 + offset], 3.5)),
                 src[7 + offset]);
   }
   for (int l = 0; l < 8; ++l) {
     int offset = l * 8;
     m[l] = vsubq_f32(vaddq_f32(vsubq_f32(vmulq_n_f32(t[offset], 0.5625), vmulq_n_f32(t[2 + offset], 3.0625)),
                                vmulq_n_f32(t[4 + offset], 3.5)),
                      t[6 + offset]);
     float32x4_t tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 1.125), vmulq_n_f32(t[5 + offset], 0.5));
     float32x4_t tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 2.25), vmulq_n_f32(t[4 + offset], 3.25));
     m[8 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 1.625)), t[6 + offset]);
     m[16 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 1.625)), t[6 + offset]);
     tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 0.5625), t[5 + offset]);
     tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 0.5625), vmulq_n_f32(t[4 + offset], 2.5));
     m[24 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 2.5)), t[6 + offset]);
     m[32 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 2.5)), t[6 + offset]);
     tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 0.375), vmulq_n_f32(t[5 + offset], 1.5));
     tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 0.25), vmulq_n_f32(t[4 + offset], 1.25));
     m[40 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 1.875)), t[6 + offset]);
     m[48 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 1.875)), t[6 + offset]);
     m[56 + l] = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t[1 + offset], -0.5625), vmulq_n_f32(t[3 + offset], 3.0625)),
                                     vmulq_n_f32(t[5 + offset], 3.5)),
                           t[7 + offset]);
   }
   for (int i = 0; i < 64; i++) {
     vst1q_f32(dst_data + i * dst_step, m[i]);
   }
-#else
+  } else {
+#endif
   float src[64];
   float t[64];
   float m[64];
-  for (int i = 0; i < C4NUM; ++i) {
+  for (int i = 0; i < real_c; ++i) {
     for (int j = 0; j < 64; ++j) {
       src[j] = src_data[i + j * src_step];
     }
     for (int l = 0; l < 8; ++l) {
       int offset = l * 8;
       t[l] = 0.5625f * src[offset] - 3.0625f * src[2 + offset] + 3.5f * src[4 + offset] - src[6 + offset];
       float tmp1 = 1.125f * src[1 + offset] + 0.5f * src[5 + offset];
       float tmp2 = 2.25f * src[2 + offset] - 3.25f * src[4 + offset];
       t[8 + l] = tmp1 + tmp2 - 1.625f * src[3 + offset] + src[6 + offset];
       t[16 + l] = tmp2 - tmp1 + 1.625f * src[3 + offset] + src[6 + offset];
       tmp1 = 0.5625f * src[1 + offset] + src[5 + offset];
       tmp2 = 0.5625f * src[2 + offset] - 2.5f * src[4 + offset];
       t[24 + l] = tmp1 + tmp2 - 2.5f * src[3 + offset] + src[6 + offset];
       t[32 + l] = tmp2 - tmp1 + 2.5f * src[3 + offset] + src[6 + offset];
       tmp1 = 0.375f * src[1 + offset] + 1.5f * src[5 + offset];
       tmp2 = 0.25f * src[2 + offset] - 1.25f * src[4 + offset];
       t[40 + l] = tmp1 + tmp2 - 1.875f * src[3 + offset] + src[6 + offset];
       t[48 + l] = tmp2 - tmp1 + 1.875f * src[3 + offset] + src[6 + offset];
       t[56 + l] = -0.5625f * src[1 + offset] + 3.0625f * src[3 + offset] - 3.5f * src[5 + offset] + src[7 + offset];
     }
     for (int l = 0; l < 8; ++l) {
       int offset = l * 8;
       m[l] = 0.5625f * t[offset] - 3.0625f * t[2 + offset] + 3.5f * t[4 + offset] - t[6 + offset];
       float tmp1 = 1.125f * t[1 + offset] + 0.5f * t[5 + offset];
       float tmp2 = 2.25f * t[2 + offset] - 3.25f * t[4 + offset];
       m[8 + l] = tmp1 + tmp2 - 1.625f * t[3 + offset] + t[6 + offset];
       m[16 + l] = tmp2 - tmp1 + 1.625f * t[3 + offset] + t[6 + offset];
       tmp1 = 0.5625f * t[1 + offset] + t[5 + offset];
       tmp2 = 0.5625f * t[2 + offset] - 2.5f * t[4 + offset];
       m[24 + l] = tmp1 + tmp2 - 2.5f * t[3 + offset] + t[6 + offset];
       m[32 + l] = tmp2 - tmp1 + 2.5f * t[3 + offset] + t[6 + offset];
       tmp1 = 0.375f * t[1 + offset] + 1.5f * t[5 + offset];
       tmp2 = 0.25f * t[2 + offset] - 1.25f * t[4 + offset];
       m[40 + l] = tmp1 + tmp2 - 1.875f * t[3 + offset] + t[6 + offset];
       m[48 + l] = tmp2 - tmp1 + 1.875f * t[3 + offset] + t[6 + offset];
       m[56 + l] = -0.5625f * t[1 + offset] + 3.0625f * t[3 + offset] - 3.5f * t[5 + offset] + t[7 + offset];
     }
     for (int k = 0; k < 64; ++k) {
       dst_data[i + k * dst_step] = m[k];
     }
   }
+#ifdef ENABLE_ARM
+  }
 #endif
 }
@@ -28,7 +28,7 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step);
+typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
 
 typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step,
                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
@@ -163,11 +163,11 @@ void GeneralOutputTransformUnit(const float *src_data, float *dst_data, const fl
 
 InputTransFunc GetInputTransFunc(int input_unit);
 
-void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step);
+void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
 
-void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step);
+void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
 
-void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step);
+void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
 
 OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type);
@@ -39,21 +39,18 @@ int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_da
   // original weight format : ohwi
   auto channel_in = conv_param_->input_channel_;
   auto channel_out = conv_param_->output_channel_;
-  int ic4 = UP_DIV(channel_in, C4NUM);
   int oc_block_num = UP_DIV(channel_out, oc_block);
-  int c4_channel = ic4 * C4NUM;
-  int block_stride = c4_channel * oc_block;
+  int block_stride = channel_in * oc_block;
   int block_num_stride = block_stride * oc_block_num;
 
   // trans_filter = G*g*GT (g represents weight_data)
   // separate into two steps ===> tmp = (g * GT)T ===> trans = (tmp * GT)T use same function:MatrixMultiplyWinograd
-  auto tmp_data = reinterpret_cast<float *>(malloc(c4_channel * input_unit_ * kernel_unit_ * sizeof(float)));
+  auto tmp_data = reinterpret_cast<float *>(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float)));
   if (tmp_data == nullptr) {
     MS_LOG(ERROR) << "malloc tmp_data failed.";
     return RET_MEMORY_FAILED;
   }
-  memset(tmp_data, 0, c4_channel * input_unit_ * kernel_unit_ * sizeof(float));
-  auto trans_out_data = reinterpret_cast<float *>(malloc(c4_channel * input_unit_ * input_unit_ * sizeof(float)));
+  auto trans_out_data = reinterpret_cast<float *>(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float)));
   if (trans_out_data == nullptr) {
     free(tmp_data);
     MS_LOG(ERROR) << "malloc trans_out_data failed.";
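A quick check of the identity behind the "separate into two steps" comment above:

    tmp   = (g * GT)T  = G * gT
    trans = (tmp * GT)T = G * tmpT = G * g * GT

so a single right-multiply routine can be applied twice to produce the full filter transform G*g*GT. On the non-ARM64 path below, the transposes are performed explicitly by PackHWCToWHC between the two MatrixMultiplyWinograd calls; the net result is the same. The memset of tmp_data also disappears here, consistent with the buffer now holding exactly channel_in values with no padding lanes to zero.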
@@ -61,14 +58,14 @@ int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_da
   }
 
 #ifndef ENABLE_ARM64
-  auto tmp_data1 = reinterpret_cast<float *>(malloc(c4_channel * input_unit_ * kernel_unit_ * sizeof(float)));
+  auto tmp_data1 = reinterpret_cast<float *>(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float)));
   if (tmp_data1 == nullptr) {
     free(tmp_data);
     free(trans_out_data);
     MS_LOG(ERROR) << "malloc tmp_data1 failed.";
     return RET_MEMORY_FAILED;
   }
-  auto trans_out_data1 = reinterpret_cast<float *>(malloc(c4_channel * input_unit_ * input_unit_ * sizeof(float)));
+  auto trans_out_data1 = reinterpret_cast<float *>(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float)));
   if (trans_out_data1 == nullptr) {
     free(tmp_data);
     free(tmp_data1);
@@ -87,30 +84,30 @@ int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_da
 #ifndef ENABLE_ARM64
     // tmp_data = g * GT
     MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit_, kernel_unit_,
-                           input_unit_, channel_in, c4_channel * 4);
+                           input_unit_, channel_in, channel_in * 4);
     // tmp_data1 = (tmp_data)T
-    PackHWCToWHC(tmp_data, tmp_data1, kernel_unit_, input_unit_, c4_channel);
+    PackHWCToWHC(tmp_data, tmp_data1, kernel_unit_, input_unit_, channel_in);
     // trans_out_data1 = tmp * GT
-    MatrixMultiplyWinograd(tmp_data1, matrix_gt, trans_out_data1, input_unit_, kernel_unit_, input_unit_, c4_channel,
-                           c4_channel * 4);
+    MatrixMultiplyWinograd(tmp_data1, matrix_gt, trans_out_data1, input_unit_, kernel_unit_, input_unit_, channel_in,
+                           channel_in * 4);
     // trans_out_data = (trans_out_data1)T
-    PackHWCToWHC(trans_out_data1, trans_out_data, input_unit_, input_unit_, c4_channel);
+    PackHWCToWHC(trans_out_data1, trans_out_data, input_unit_, input_unit_, channel_in);
 #else
     // tmp = (g * GT)T
     MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit_, kernel_unit_,
-                           input_unit_, channel_in, c4_channel * 4);
+                           input_unit_, channel_in, channel_in * 4);
     // trans = (tmp * GT)T
-    MatrixMultiplyWinograd(tmp_data, matrix_gt, trans_out_data, input_unit_, kernel_unit_, input_unit_, c4_channel,
-                           c4_channel * 4);
+    MatrixMultiplyWinograd(tmp_data, matrix_gt, trans_out_data, input_unit_, kernel_unit_, input_unit_, channel_in,
+                           channel_in * 4);
 #endif
 
     int in_offset = 0;
     for (int j = 0; j < input_unit_; ++j) {
       for (int k = 0; k < input_unit_; ++k) {
-        for (int c = 0; c < c4_channel; ++c) {
+        for (int c = 0; c < channel_in; ++c) {
           *(trans_weight_ + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c];
         }
-        in_offset += c4_channel;
+        in_offset += channel_in;
         output_oz_offset += block_num_stride;
       }
     }
@@ -128,7 +125,6 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
   auto filter_tensor = in_tensors_.at(kWeightIndex);
   int in_channel = filter_tensor->Channel();
   int out_channel = filter_tensor->Batch();
-  int ic4 = UP_DIV(in_channel, C4NUM);
   conv_param_->input_channel_ = in_channel;
   conv_param_->output_channel_ = out_channel;
@@ -137,7 +133,7 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
   int oc_block_num = UP_DIV(out_channel, C8NUM);
 
   // set data
-  auto trans_matrix_data_size = input_unit_ * input_unit_ * ic4 * C4NUM * oc_block_num * oc_block * sizeof(float);
+  auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float);
   trans_weight_ = reinterpret_cast<float *>(malloc(trans_matrix_data_size));
   if (trans_weight_ == nullptr) {
     MS_LOG(ERROR) << "malloc matrix_buffer failed.";
@@ -188,7 +184,6 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
 int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
   int channel_out = conv_param_->output_channel_;
   int oc8 = UP_DIV(channel_out, C8NUM);
-  int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
 #ifdef ENABLE_ARM32
   int tile_num = 4;
 #else
@@ -196,7 +191,8 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
 #endif
   MS_ASSERT(ctx_->allocator != nullptr);
 
-  size_t tile_buffer_size = thread_count_ * tile_num * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float);
+  size_t tile_buffer_size =
+    thread_count_ * tile_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float);
   trans_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));
   if (trans_input_ == nullptr) {
     MS_LOG(ERROR) << "malloc trans_input_ failed.";
@@ -217,8 +213,8 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
     return RET_MEMORY_FAILED;
   }
 
-  col_buffer_ =
-    reinterpret_cast<float *>(ctx_->allocator->Malloc(thread_count_ * tile_num * ic4 * C4NUM * sizeof(float)));
+  col_buffer_ = reinterpret_cast<float *>(
+    ctx_->allocator->Malloc(thread_count_ * tile_num * conv_param_->input_channel_ * sizeof(float)));
   if (col_buffer_ == nullptr) {
     MS_LOG(ERROR) << "malloc col_buffer_ failed.";
     return RET_ERROR;
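The same exact-channel sizing applies to the thread-local buffers. A hypothetical before/after, assuming C4NUM == 4 with thread_count_ = 2, tile_num = 12, input_unit_ = 4 and input_channel_ = 3:

    old tile_buffer_size: 2 * 12 * 4 * 4 * UP_DIV(3, 4) * 4 = 1536 floats
    new tile_buffer_size: 2 * 12 * 4 * 4 * 3                = 1152 floats

    old col_buffer: 2 * 12 * 1 * 4 = 96 floats
    new col_buffer: 2 * 12 * 3     = 72 floats

For channel counts that are not multiples of four, the allocations and the memory traffic through the input transform shrink by the padding fraction; when input_channel_ is already a multiple of four, the sizes are unchanged.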