forked from mindspore-Ecosystem/mindspore
[MSLITE][DEVELOP] optimize conv winograd
This commit is contained in:
parent ad5c5ce5f8
commit 4b0edb34ce

@@ -195,3 +195,4 @@ mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition
mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::PartitionNode
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c:InstanceNormNC8HW8Fp16
mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::init_global_variable
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.c:ConvWinogardFp32
@@ -187,7 +187,7 @@ void Conv1x1OutNc8hw8MultiThreadByWeightFp16(const float16_t *input, float16_t *
// fp16 convolution winograd
void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight, const float16_t *bias_data,
float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id,
const ConvParameter *conv_param, InputTransFp16Func in_func, OutputTransFp16Func out_func) {
const ConvParameter *conv_param, TransFp16FuncList trans_func) {
#ifdef ENABLE_ARM64
const int tile_num = 16;
#else
@@ -196,6 +196,7 @@ void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight
NNACL_CHECK_ZERO_RETURN(conv_param->output_unit_);
NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_);
int in_channel = conv_param->input_channel_;
int input_unit = conv_param->input_unit_;
int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_);
int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_);
int output_count = out_w_block * out_h_block;
@@ -204,16 +205,12 @@ void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight
NNACL_CHECK_ZERO_RETURN(real_tile);
int output_tile_count = UP_DIV(output_count, real_tile);
int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_;
int input_unit_square = input_unit * input_unit;

float16_t *trans_input = buffer_list[0];
float16_t *gemm_out = buffer_list[1];
float16_t *tmp_data = buffer_list[2];
float16_t *col_buffer = buffer_list[3];
int trans_input_offset = tile_num * input_unit_square * in_channel;
int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
int tmp_data_offset = input_unit_square * C8NUM;
int col_buffer_offset = tile_num * in_channel;
float16_t *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel;
float16_t *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM;
float16_t *tmp_data = buffer_list[2] + task_id * input_unit_square * C8NUM;
float16_t *col_buffer = buffer_list[3] + task_id * tile_num * in_channel;
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)
for (int b = 0; b < conv_param->input_batch_; b++) {
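Note on the refactor just above: the old code kept shared base pointers plus explicit *_offset variables and added task_id * offset at every use site; the new code folds the per-thread offset into each scratch pointer once, and the strides themselves are unchanged. As a rough illustration only (the concrete sizes below are an assumed example, not taken from the patch), the per-thread scratch element counts implied by that pointer arithmetic can be checked like this:

#include <stdio.h>

/* Illustration only: per-thread scratch-buffer element counts implied by the
 * pointer arithmetic above, for an assumed fp16 convolution with
 * input_unit = 4, in_channel = 32, output_channel = 64 and the arm64 tile_num of 16. */
int main(void) {
  const int C8NUM = 8; /* channel pack size used throughout nnacl */
  const int tile_num = 16, input_unit = 4, in_channel = 32, output_channel = 64;
  const int oc8 = (output_channel + C8NUM - 1) / C8NUM; /* UP_DIV(output_channel, C8NUM) */
  const int input_unit_square = input_unit * input_unit;
  printf("trans_input per thread: %d fp16 elements\n", tile_num * input_unit_square * in_channel);
  printf("gemm_out    per thread: %d fp16 elements\n", tile_num * input_unit_square * oc8 * C8NUM);
  printf("tmp_data    per thread: %d fp16 elements\n", input_unit_square * C8NUM);
  printf("col_buffer  per thread: %d fp16 elements\n", tile_num * in_channel);
  return 0;
}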
@@ -226,30 +223,64 @@ void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight
|
|||
if (cal_num <= 0) {
|
||||
return;
|
||||
}
|
||||
WinogradInputTransformFp16(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
|
||||
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
in_func);
|
||||
// step 3 : gemm
|
||||
float16_t *src_ptr = trans_input + task_id * trans_input_offset;
|
||||
float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset;
|
||||
float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
|
||||
for (int i = 0; i < input_unit_square; ++i) {
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
|
||||
// Optimized input transform. Only valid for arm64, where the tile num is 16.
// For arm32 the tile_num is 12, so the function (InputTransform4x4Pack12Fp16) would need to be rewritten.
|
||||
bool fused_pack =
|
||||
(cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL);
|
||||
if (fused_pack) {
|
||||
float16_t *opt_trans_input =
|
||||
buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, C8NUM);
|
||||
WinogradInputTransformOptStepFp16(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num,
|
||||
out_tile_index, out_w_block, conv_param, trans_func.in_step_func_);
|
||||
|
||||
for (int w_index = 0; w_index < input_unit; w_index++) {
|
||||
float16_t *src_w = opt_trans_input + w_index * input_unit * tile_num * C8NUM;
|
||||
for (int c = 0; c < UP_DIV(in_channel, C8NUM); c++) {
|
||||
int real_c = in_channel - c * C8NUM;
|
||||
real_c = real_c > C8NUM ? C8NUM : real_c;
|
||||
float16_t *src_c = src_w + c * input_unit_square * tile_num * C8NUM;
|
||||
float16_t *dst_c = trans_input + c * tile_num * C8NUM;
|
||||
trans_func.in_pack_func_(src_c, dst_c, C8NUM, in_channel * tile_num, real_c);
|
||||
}
|
||||
|
||||
for (int h_index = 0; h_index < input_unit; h_index++) {
|
||||
const float16_t *gemm_input = trans_input + h_index * tile_num * in_channel;
|
||||
int point_index = h_index * input_unit + w_index;
|
||||
const float16_t *gemm_weight = trans_weight + point_index * in_channel * oc8 * C8NUM;
|
||||
MatMulFp16(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num,
|
||||
oc8 * C8NUM, input_unit_square, OutType_TileC8);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#endif
|
||||
WinogradInputTransformFp16(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index,
|
||||
out_w_block, conv_param, trans_func.in_func_);
|
||||
// step 3 : gemm
|
||||
float16_t *src_ptr = trans_input;
|
||||
float16_t *dst_ptr = gemm_out;
|
||||
float16_t *tmp_col_ptr = col_buffer;
|
||||
for (int i = 0; i < input_unit_square; ++i) {
|
||||
#ifdef ENABLE_ARM64
|
||||
RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
|
||||
#else
|
||||
RowMajor2Col12MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
|
||||
#endif
|
||||
MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
|
||||
cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8);
|
||||
MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
|
||||
cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8);
|
||||
}
|
||||
#ifdef ENABLE_ARM64
|
||||
}
|
||||
#endif
|
||||
|
||||
// step 4 : output transform
|
||||
if (conv_param->out_format_ != NNACL_NC4HW4) { // nc4hw4
|
||||
WinogradOutputNHWCTransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset, bias_data,
|
||||
cal_num, out_tile_index, out_w_block, conv_param, out_func);
|
||||
WinogradOutputNHWCTransformFp16(gemm_out, output_data + out_batch_offset, bias_data, cal_num, out_tile_index,
|
||||
out_w_block, conv_param, trans_func.out_func_);
|
||||
} else {
|
||||
WinogradOutputNC8HW8TransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset,
|
||||
bias_data, cal_num, out_tile_index, out_w_block, conv_param, out_func);
|
||||
WinogradOutputNC8HW8TransformFp16(gemm_out, output_data + out_batch_offset, bias_data, cal_num, out_tile_index,
|
||||
out_w_block, conv_param, trans_func.out_func_);
|
||||
}
|
||||
}
|
||||
}
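For context (not part of the patch): ConvWinogardFp16 now receives all four transform callbacks in a single TransFp16FuncList instead of separate in_func/out_func arguments. A minimal sketch of how a caller could populate it from the getters this commit adds; only the struct fields and getter names come from the patch, while the helper itself and the header path are assumptions:

#include "nnacl/fp16/winograd_utils_fp16.h" /* header path assumed */

/* Hypothetical helper, illustration only. */
static TransFp16FuncList BuildTransFp16FuncList(int input_unit, int output_unit, ActType act_type) {
  TransFp16FuncList trans_func = {NULL, NULL, NULL, NULL};
  trans_func.in_func_ = GetInputTransFp16Func(input_unit);
  trans_func.out_func_ = GetOutputTransFp16Func(input_unit, output_unit, act_type);
#ifdef ENABLE_ARM64
  /* The fused path in ConvWinogardFp16 is taken only when both of these are
   * non-NULL and the tile is full (cal_num == tile_num). */
  trans_func.in_step_func_ = GetInputTransStepFp16Func(input_unit);
  trans_func.in_pack_func_ = GetInputTransPackFp16Func(input_unit);
#endif
  return trans_func;
}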
|
||||
|
|
|
@@ -49,7 +49,7 @@ void Conv1x1OutNc8hw8MultiThreadByWeightFp16(const float16_t *input, float16_t *
|
|||
// fp16 convolution winograd
|
||||
void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight, const float16_t *bias_data,
|
||||
float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id,
|
||||
const ConvParameter *conv_param, InputTransFp16Func in_func, OutputTransFp16Func out_func);
|
||||
const ConvParameter *conv_param, TransFp16FuncList trans_func);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@@ -16,6 +16,77 @@
|
|||
|
||||
#include "nnacl/fp16/winograd_transform_fp16.h"
|
||||
|
||||
void PrepareTransInputFp16(const float16_t *src_data, float16_t *dst_data, int interval_x_s, int interval_x_e,
|
||||
int interval_y_s, int interval_y_e, int real_c, const ConvParameter *conv_param) {
|
||||
int input_unit = conv_param->input_unit_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
int input_w = conv_param->input_w_;
|
||||
|
||||
// clear tmp buffer
|
||||
if (interval_x_e - interval_x_s != input_unit || interval_y_e - interval_y_s != input_unit) {
|
||||
memset(dst_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t));
|
||||
}
|
||||
|
||||
// get real input block with padding
|
||||
if (real_c == C8NUM) {
|
||||
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
|
||||
int src_y_offset = (interval * input_w + interval_x_s) * in_channel;
|
||||
int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM;
|
||||
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
|
||||
int src_x_offset = src_y_offset + j * in_channel;
|
||||
int dst_x_offset = dst_y_offset + j * C8NUM;
|
||||
const float16_t *src_addr = src_data + src_x_offset;
|
||||
float16_t *dst_addr = dst_data + dst_x_offset;
|
||||
#ifdef ENABLE_NEON
|
||||
vst1q_f16(dst_addr, vld1q_f16(src_addr));
|
||||
#else
|
||||
for (int k = 0; k < C8NUM; k++) {
|
||||
dst_addr[k] = src_addr[k];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
} else if (real_c < 8 && real_c >= 4) {
|
||||
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
|
||||
int src_y_offset = (interval * input_w + interval_x_s) * in_channel;
|
||||
int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM;
|
||||
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
|
||||
int src_x_offset = src_y_offset + j * in_channel;
|
||||
int dst_x_offset = dst_y_offset + j * C8NUM;
|
||||
const float16_t *src_addr = src_data + src_x_offset;
|
||||
float16_t *dst_addr = dst_data + dst_x_offset;
|
||||
int rc = real_c - 4;
|
||||
#ifdef ENABLE_NEON
|
||||
vst1_f16(dst_addr, vld1_f16(src_addr));
|
||||
#else
|
||||
for (int k = 0; k < C4NUM; k++) {
|
||||
dst_addr[k] = src_addr[k];
|
||||
}
|
||||
#endif
|
||||
src_addr += 4;
|
||||
dst_addr += 4;
|
||||
for (int i = 0; i < rc; ++i) {
|
||||
dst_addr[i] = src_addr[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
|
||||
int src_y_offset = (interval * input_w + interval_x_s) * in_channel;
|
||||
int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM;
|
||||
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
|
||||
int src_x_offset = src_y_offset + j * in_channel;
|
||||
int dst_x_offset = dst_y_offset + j * C8NUM;
|
||||
const float16_t *src_addr = src_data + src_x_offset;
|
||||
float16_t *dst_addr = dst_data + dst_x_offset;
|
||||
for (int k = 0; k < real_c; k++) {
|
||||
dst_addr[k] = src_addr[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
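PrepareTransInputFp16 factors out the padded-block copy that used to be inlined in WinogradInputTransformFp16: a full 8-channel group is copied with one NEON load/store, 4 to 7 channels use a half vector plus a scalar tail, and fewer than 4 channels fall back to a scalar loop; the memset above keeps rows and columns that fall outside the image at zero. A small self-contained example (the concrete numbers are assumed, the interval formulas are the ones used by the callers below) showing which part of the 4x4 block is actually copied for a top-left border tile:

#include <stdio.h>

/* Illustration only: interval bounds for a top-left border tile with
 * input_unit = 4, output_unit = 2, pad = 1 on an 8x8 input. */
int main(void) {
  int input_unit = 4, output_unit = 2, pad = 1, input_w = 8, input_h = 8;
  int out_tile_index = 0, out_w_block_num = 4;
  int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad; /* -1 */
  int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad; /* -1 */
  int interval_x_s = src_x_s > 0 ? 0 : -src_x_s;                        /*  1 */
  int interval_y_s = src_y_s > 0 ? 0 : -src_y_s;                        /*  1 */
  int src_x_e = src_x_s + input_unit;
  int src_y_e = src_y_s + input_unit;
  int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s);
  int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s);
  /* Prints x:[1,4) y:[1,4): the first row and column of the 4x4 block stay zero (padding). */
  printf("x:[%d,%d) y:[%d,%d)\n", interval_x_s, interval_x_e, interval_y_s, interval_y_e);
  return 0;
}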
|
||||
|
||||
// fp16 common winograd
|
||||
void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
|
||||
int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
|
||||
|
@@ -49,71 +120,11 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
|
|||
int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s);
|
||||
int dst_plane_offset = c * in_channel;
|
||||
for (int ic = 0; ic < ic8; ic++) {
|
||||
// clear tmp buffer
|
||||
memset(tmp_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t));
|
||||
|
||||
int real_c = in_channel - ic * C8NUM;
|
||||
real_c = real_c > C8NUM ? C8NUM : real_c;
|
||||
int src_ic8_offset = src_plane_offset + ic * C8NUM;
|
||||
|
||||
// get real input block with padding
|
||||
if (real_c == C8NUM) {
|
||||
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
|
||||
int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel;
|
||||
int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM;
|
||||
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
|
||||
int src_x_offset = src_y_offset + j * in_channel;
|
||||
int dst_x_offset = dst_y_offset + j * C8NUM;
|
||||
const float16_t *src_addr = input_data + src_x_offset;
|
||||
float16_t *dst_addr = tmp_data + dst_x_offset;
|
||||
#ifdef ENABLE_NEON
|
||||
vst1q_f16(dst_addr, vld1q_f16(src_addr));
|
||||
#else
|
||||
for (int k = 0; k < C8NUM; k++) {
|
||||
dst_addr[k] = src_addr[k];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
} else if (real_c < 8 && real_c >= 4) {
|
||||
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
|
||||
int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel;
|
||||
int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM;
|
||||
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
|
||||
int src_x_offset = src_y_offset + j * in_channel;
|
||||
int dst_x_offset = dst_y_offset + j * C8NUM;
|
||||
const float16_t *src_addr = input_data + src_x_offset;
|
||||
float16_t *dst_addr = tmp_data + dst_x_offset;
|
||||
int rc = real_c - 4;
|
||||
#ifdef ENABLE_NEON
|
||||
vst1_f16(dst_addr, vld1_f16(src_addr));
|
||||
#else
|
||||
for (int k = 0; k < C4NUM; k++) {
|
||||
dst_addr[k] = src_addr[k];
|
||||
}
|
||||
#endif
|
||||
src_addr += 4;
|
||||
dst_addr += 4;
|
||||
for (int i = 0; i < rc; ++i) {
|
||||
dst_addr[i] = src_addr[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
|
||||
int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel;
|
||||
int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM;
|
||||
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
|
||||
int src_x_offset = src_y_offset + j * in_channel;
|
||||
int dst_x_offset = dst_y_offset + j * C8NUM;
|
||||
const float16_t *src_addr = input_data + src_x_offset;
|
||||
float16_t *dst_addr = tmp_data + dst_x_offset;
|
||||
for (int k = 0; k < real_c; k++) {
|
||||
dst_addr[k] = src_addr[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
const float16_t *src_data = input_data + src_plane_offset + ic * C8NUM;
|
||||
PrepareTransInputFp16(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c,
|
||||
conv_param);
|
||||
|
||||
// input transform
|
||||
int dst_ic8_offset = dst_plane_offset + ic * C8NUM;
|
||||
|
@@ -125,6 +136,51 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
|
|||
} // cal_tile_num loop
|
||||
}
|
||||
|
||||
// Only supported on arm64
|
||||
void WinogradInputTransformOptStepFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
|
||||
int cal_num, int out_tile_index, int out_w_block_num,
|
||||
const ConvParameter *conv_param, InputTransStepFp16Func func) {
|
||||
const int tile_num = 16;
|
||||
int input_unit = conv_param->input_unit_;
|
||||
int output_unit = conv_param->output_unit_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
int ic8 = UP_DIV(in_channel, C8NUM);
|
||||
int pad_h = conv_param->pad_u_;
|
||||
int pad_w = conv_param->pad_l_;
|
||||
int input_h = conv_param->input_h_;
|
||||
int input_w = conv_param->input_w_;
|
||||
if (out_w_block_num == 0) {
|
||||
return;
|
||||
}
|
||||
for (int c = 0; c < cal_num; c++) { // actual tiled number
|
||||
int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w;
|
||||
int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h;
|
||||
int interval_x_s = src_x_s > 0 ? 0 : -src_x_s;
|
||||
int interval_y_s = src_y_s > 0 ? 0 : -src_y_s;
|
||||
int src_x_e = src_x_s + input_unit;
|
||||
int src_y_e = src_y_s + input_unit;
|
||||
int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s);
|
||||
int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s);
|
||||
|
||||
int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s);
|
||||
int dst_plane_offset = c * C8NUM;
|
||||
for (int ic = 0; ic < ic8; ic++) {
|
||||
int real_c = in_channel - ic * C8NUM;
|
||||
real_c = real_c > C8NUM ? C8NUM : real_c;
|
||||
const float16_t *src_data = input_data + src_plane_offset + ic * C8NUM;
|
||||
PrepareTransInputFp16(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c,
|
||||
conv_param);
|
||||
|
||||
// input transform
|
||||
int dst_ic8_offset = dst_plane_offset + ic * tile_num * input_unit * input_unit * C8NUM;
|
||||
size_t dst_step = input_unit * tile_num * C8NUM;
|
||||
float16_t *trans_input_ptr = trans_input + dst_ic8_offset;
|
||||
func(tmp_data, trans_input_ptr, C8NUM, dst_step, tile_num * C8NUM);
|
||||
}
|
||||
out_tile_index++;
|
||||
} // cal_tile_num loop
|
||||
}
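How the optimized path is split (an interpretation of the code, stated as the usual separable form of the Winograd input transform):

$$ \tilde{d} = B^\top d\, B = B^\top (d\, B). $$

The step function above applies the one-dimensional transform along one axis for a whole block of 16 tiles while keeping the C8 channel packing; the InputTransformNxNPack16Fp16 functions in the transform-utility file further down apply the remaining one-dimensional transform and simultaneously repack tiles x channels into the layout MatMulFp16 consumes, which is what lets the fused path in ConvWinogardFp16 skip the separate RowMajor2Col16MajorFp16Opt call.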
|
||||
|
||||
void WinogradOutputNHWCTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
|
||||
int cal_num, int out_tile_index, int output_unit_num,
|
||||
const ConvParameter *conv_param, OutputTransFp16Func func) {
|
||||
|
|
|
@@ -33,6 +33,10 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
|
|||
int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
|
||||
InputTransFp16Func func);
|
||||
|
||||
void WinogradInputTransformOptStepFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
|
||||
int cal_num, int out_tile_index, int out_w_block_num,
|
||||
const ConvParameter *conv_param, InputTransStepFp16Func func);
|
||||
|
||||
void WinogradOutputNHWCTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
|
||||
int cal_num, int out_tile_index, int output_unit_num,
|
||||
const ConvParameter *conv_param, OutputTransFp16Func func);
|
||||
|
|
|
@@ -20,6 +20,38 @@
|
|||
#define MIN_UNIT_FP16 2
|
||||
#define MAX_UNIT_FP16 4
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void transpose8(float16x8_t *s0, float16x8_t *s1, float16x8_t *s2, float16x8_t *s3, float16x8_t *s4, float16x8_t *s5,
|
||||
float16x8_t *s6, float16x8_t *s7) {
|
||||
float32x4_t m0 = (float32x4_t)(vtrn1q_f16(*s0, *s1));
|
||||
float32x4_t m1 = (float32x4_t)(vtrn2q_f16(*s0, *s1));
|
||||
float32x4_t m2 = (float32x4_t)(vtrn1q_f16(*s2, *s3));
|
||||
float32x4_t m3 = (float32x4_t)(vtrn2q_f16(*s2, *s3));
|
||||
float32x4_t m4 = (float32x4_t)(vtrn1q_f16(*s4, *s5));
|
||||
float32x4_t m5 = (float32x4_t)(vtrn2q_f16(*s4, *s5));
|
||||
float32x4_t m6 = (float32x4_t)(vtrn1q_f16(*s6, *s7));
|
||||
float32x4_t m7 = (float32x4_t)(vtrn2q_f16(*s6, *s7));
|
||||
|
||||
float64x2_t t0 = (float64x2_t)(vtrn1q_f32(m0, m2));
|
||||
float64x2_t t2 = (float64x2_t)(vtrn2q_f32(m0, m2));
|
||||
float64x2_t t1 = (float64x2_t)(vtrn1q_f32(m1, m3));
|
||||
float64x2_t t3 = (float64x2_t)(vtrn2q_f32(m1, m3));
|
||||
float64x2_t t4 = (float64x2_t)(vtrn1q_f32(m4, m6));
|
||||
float64x2_t t6 = (float64x2_t)(vtrn2q_f32(m4, m6));
|
||||
float64x2_t t5 = (float64x2_t)(vtrn1q_f32(m5, m7));
|
||||
float64x2_t t7 = (float64x2_t)(vtrn2q_f32(m5, m7));
|
||||
|
||||
*s0 = (float16x8_t)(vtrn1q_f64(t0, t4));
|
||||
*s4 = (float16x8_t)(vtrn2q_f64(t0, t4));
|
||||
*s1 = (float16x8_t)(vtrn1q_f64(t1, t5));
|
||||
*s5 = (float16x8_t)(vtrn2q_f64(t1, t5));
|
||||
*s2 = (float16x8_t)(vtrn1q_f64(t2, t6));
|
||||
*s6 = (float16x8_t)(vtrn2q_f64(t2, t6));
|
||||
*s3 = (float16x8_t)(vtrn1q_f64(t3, t7));
|
||||
*s7 = (float16x8_t)(vtrn2q_f64(t3, t7));
|
||||
}
|
||||
#endif
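For readers unfamiliar with the vtrn idiom: transpose8 performs an in-register 8x8 transpose in three stages (16-bit, 32-bit, then 64-bit interleaves). A plain scalar equivalent, written with float elements purely for illustration:

/* Scalar reference for transpose8, illustration only (the element type is immaterial). */
static void Transpose8x8Ref(const float src[8][8], float dst[8][8]) {
  for (int r = 0; r < 8; ++r) {
    for (int c = 0; c < 8; ++c) {
      dst[c][r] = src[r][c]; /* dst is the transpose of src */
    }
  }
}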
|
||||
|
||||
static InputTransFp16Func InputTransFp16FuncList[] = {
|
||||
NULL, NULL, NULL, NULL, InputTransform4x4UnitFp16, NULL, InputTransform6x6UnitFp16, NULL, InputTransform8x8UnitFp16};
|
||||
|
||||
|
@@ -81,6 +113,25 @@ static OutputTransFp16Func OutputTransFp16FuncRelu6List8[] = {NULL,
|
|||
|
||||
InputTransFp16Func GetInputTransFp16Func(int input_unit) { return InputTransFp16FuncList[input_unit]; }
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
static InputTransStepFp16Func InputTransStepFp16FuncList[] = {
|
||||
NULL, NULL, NULL, NULL, InputTransform4x4StepFp16, NULL, InputTransform6x6StepFp16, NULL, InputTransform8x8StepFp16};
|
||||
|
||||
static InputTransPackFp16Func InputTransPackFp16FuncList[] = {NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
InputTransform4x4Pack16Fp16,
|
||||
NULL,
|
||||
InputTransform6x6Pack16Fp16,
|
||||
NULL,
|
||||
InputTransform8x8Pack16Fp16};
|
||||
|
||||
InputTransStepFp16Func GetInputTransStepFp16Func(int input_unit) { return InputTransStepFp16FuncList[input_unit]; }
|
||||
|
||||
InputTransPackFp16Func GetInputTransPackFp16Func(int input_unit) { return InputTransPackFp16FuncList[input_unit]; }
|
||||
#endif
|
||||
|
||||
void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) {
|
||||
int j = 0;
|
||||
if (real_c == 8) {
|
||||
|
@@ -160,6 +211,74 @@ void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, i
|
|||
}
|
||||
}
|
||||
|
||||
void InputTransform4x4StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
|
||||
int dst_row_step) {
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const float16_t *src_ptr = src_data + l * 4 * src_step;
|
||||
float16_t *dst_ptr = dst_data + l * dst_row_step;
|
||||
|
||||
float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step);
|
||||
float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step);
|
||||
float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step);
|
||||
float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step);
|
||||
float16x8_t m0 = vsubq_f16(s0, s2);
|
||||
float16x8_t m1 = vaddq_f16(s1, s2);
|
||||
float16x8_t m2 = vsubq_f16(s2, s1);
|
||||
float16x8_t m3 = vsubq_f16(s3, s1);
|
||||
|
||||
vst1q_f16(dst_ptr + 0 * dst_step, m0);
|
||||
vst1q_f16(dst_ptr + 1 * dst_step, m1);
|
||||
vst1q_f16(dst_ptr + 2 * dst_step, m2);
|
||||
vst1q_f16(dst_ptr + 3 * dst_step, m3);
|
||||
}
|
||||
}
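Background for the step/pack routines in this file: a Winograd output tile is $Y = A^\top\left[(G g G^\top) \odot (B^\top d B)\right] A$, matching the "step 1" to "step 4" comments in ConvWinogardFp16, and the input-transform functions here implement the data-side factor $B^\top d B$. Reading the arithmetic in InputTransform4x4StepFp16 back into matrix form, each group of four inputs is multiplied by the standard $B^\top$ for F(2x2, 3x3):

$$ B^\top = \begin{pmatrix} 1 & 0 & -1 & 0 \\ 0 & 1 & 1 & 0 \\ 0 & -1 & 1 & 0 \\ 0 & -1 & 0 & 1 \end{pmatrix}, \qquad m_i = \sum_j B^\top_{ij}\, s_j . $$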
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void InputTransform4x4Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile,
|
||||
int src_point_stride) {
|
||||
LOAD_LINE_DATA_FP16(0);
|
||||
LOAD_LINE_DATA_FP16(1);
|
||||
LOAD_LINE_DATA_FP16(2);
|
||||
LOAD_LINE_DATA_FP16(3);
|
||||
|
||||
float16x8_t m0 = vsubq_f16(s00, s20);
|
||||
float16x8_t m1 = vsubq_f16(s01, s21);
|
||||
vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
m0 = vaddq_f16(s10, s20);
|
||||
m1 = vaddq_f16(s11, s21);
|
||||
vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
m0 = vsubq_f16(s20, s10);
|
||||
m1 = vsubq_f16(s21, s11);
|
||||
vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
m0 = vsubq_f16(s30, s10);
|
||||
m1 = vsubq_f16(s31, s11);
|
||||
vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1);
|
||||
}
|
||||
|
||||
void InputTransform4x4Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) {
|
||||
int block_tile = 16;
|
||||
int pack_tile = src_step;
|
||||
int src_point_stride = block_tile * pack_tile;
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
float16_t *src_ptr = src_data + l * C8NUM * block_tile;
|
||||
TRANSPOSE_16x8;
|
||||
}
|
||||
|
||||
for (int c = 0; c < real_c; ++c) {
|
||||
float16_t *src_ptr = src_data + c * block_tile;
|
||||
float16_t *dst_ptr = dst_data + c * block_tile;
|
||||
InputTransform4x4Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) {
|
||||
int j = 0;
|
||||
if (real_c == 8) {
|
||||
|
@@ -272,6 +391,95 @@ void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, i
|
|||
}
|
||||
}
|
||||
|
||||
void InputTransform6x6StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
|
||||
int dst_row_step) {
|
||||
for (int l = 0; l < 6; ++l) {
|
||||
const float16_t *src_ptr = src_data + l * 6 * src_step;
|
||||
float16_t *dst_ptr = dst_data + l * dst_row_step;
|
||||
|
||||
float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step);
|
||||
float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step);
|
||||
float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step);
|
||||
float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step);
|
||||
float16x8_t s4 = vld1q_f16(src_ptr + 4 * src_step);
|
||||
float16x8_t s5 = vld1q_f16(src_ptr + 5 * src_step);
|
||||
|
||||
float16x8_t tmp1 = vsubq_f16(s3, s1);
|
||||
float16x8_t tmp2 = vsubq_f16(s4, s2);
|
||||
float16x8_t m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s0, 4), vmulq_n_f16(s2, 5)), s4);
|
||||
float16x8_t m1 = vaddq_f16(vmulq_n_f16(vaddq_f16(s1, s2), -4), vaddq_f16(s3, s4));
|
||||
float16x8_t m2 = vaddq_f16(vmulq_n_f16(vsubq_f16(s1, s2), 4), vsubq_f16(s4, s3));
|
||||
float16x8_t m3 = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2);
|
||||
float16x8_t m4 = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2);
|
||||
float16x8_t m5 = vaddq_f16(vsubq_f16(vmulq_n_f16(s1, 4), vmulq_n_f16(s3, 5)), s5);
|
||||
|
||||
vst1q_f16(dst_ptr + 0 * dst_step, m0);
|
||||
vst1q_f16(dst_ptr + 1 * dst_step, m1);
|
||||
vst1q_f16(dst_ptr + 2 * dst_step, m2);
|
||||
vst1q_f16(dst_ptr + 3 * dst_step, m3);
|
||||
vst1q_f16(dst_ptr + 4 * dst_step, m4);
|
||||
vst1q_f16(dst_ptr + 5 * dst_step, m5);
|
||||
}
|
||||
}
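Likewise, the coefficients in InputTransform6x6StepFp16 are the rows of the standard $B^\top$ for F(4x4, 3x3), written out here only to make the fused code easier to check against:

$$ B^\top = \begin{pmatrix}
4 & 0 & -5 & 0 & 1 & 0 \\
0 & -4 & -4 & 1 & 1 & 0 \\
0 & 4 & -4 & -1 & 1 & 0 \\
0 & -2 & -1 & 2 & 1 & 0 \\
0 & 2 & -1 & -2 & 1 & 0 \\
0 & 4 & 0 & -5 & 0 & 1
\end{pmatrix} $$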
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void InputTransform6x6Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile,
|
||||
int src_point_stride) {
|
||||
LOAD_LINE_DATA_FP16(0);
|
||||
LOAD_LINE_DATA_FP16(1);
|
||||
LOAD_LINE_DATA_FP16(2);
|
||||
LOAD_LINE_DATA_FP16(3);
|
||||
LOAD_LINE_DATA_FP16(4);
|
||||
LOAD_LINE_DATA_FP16(5);
|
||||
|
||||
float16x8_t m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s00, 4), vmulq_n_f16(s20, 5)), s40);
|
||||
float16x8_t m1 = vaddq_f16(vsubq_f16(vmulq_n_f16(s01, 4), vmulq_n_f16(s21, 5)), s41);
|
||||
vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
m0 = vaddq_f16(vmulq_n_f16(vaddq_f16(s10, s20), -4), vaddq_f16(s30, s40));
|
||||
m1 = vaddq_f16(vmulq_n_f16(vaddq_f16(s11, s21), -4), vaddq_f16(s31, s41));
|
||||
vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s10, s20), 4), vsubq_f16(s40, s30));
|
||||
m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s11, s21), 4), vsubq_f16(s41, s31));
|
||||
vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s30, s10), 2), vsubq_f16(s40, s20));
|
||||
m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s31, s11), 2), vsubq_f16(s41, s21));
|
||||
vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s30, s10), -2), vsubq_f16(s40, s20));
|
||||
m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s31, s11), -2), vsubq_f16(s41, s21));
|
||||
vst1q_f16(dst_ptr + 4 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 4 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s10, 4), vmulq_n_f16(s30, 5)), s50);
|
||||
m1 = vaddq_f16(vsubq_f16(vmulq_n_f16(s11, 4), vmulq_n_f16(s31, 5)), s51);
|
||||
vst1q_f16(dst_ptr + 5 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 5 * dst_step + 1 * pack_tile, m1);
|
||||
}
|
||||
|
||||
void InputTransform6x6Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) {
|
||||
int block_tile = 16;
|
||||
int pack_tile = src_step;
|
||||
int src_point_stride = block_tile * pack_tile;
|
||||
for (int l = 0; l < 6; ++l) {
|
||||
float16_t *src_ptr = src_data + l * C8NUM * block_tile;
|
||||
TRANSPOSE_16x8;
|
||||
}
|
||||
|
||||
for (int c = 0; c < real_c; ++c) {
|
||||
float16_t *src_ptr = src_data + c * block_tile;
|
||||
float16_t *dst_ptr = dst_data + c * block_tile;
|
||||
InputTransform6x6Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) {
|
||||
int j = 0;
|
||||
if (real_c == 8) {
|
||||
|
@@ -429,6 +637,133 @@ void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, i
|
|||
}
|
||||
}
|
||||
|
||||
void InputTransform8x8StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
|
||||
int dst_row_step) {
|
||||
for (int l = 0; l < 8; ++l) {
|
||||
const float16_t *src_ptr = src_data + l * 8 * src_step;
|
||||
float16_t *dst_ptr = dst_data + l * dst_row_step;
|
||||
|
||||
float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step);
|
||||
float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step);
|
||||
float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step);
|
||||
float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step);
|
||||
float16x8_t s4 = vld1q_f16(src_ptr + 4 * src_step);
|
||||
float16x8_t s5 = vld1q_f16(src_ptr + 5 * src_step);
|
||||
float16x8_t s6 = vld1q_f16(src_ptr + 6 * src_step);
|
||||
float16x8_t s7 = vld1q_f16(src_ptr + 7 * src_step);
|
||||
|
||||
float16x8_t m0 =
|
||||
vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s0, 0.5625), vmulq_n_f16(s2, 3.0625)), vmulq_n_f16(s4, 3.5)), s6);
|
||||
float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(s1, 1.125), vmulq_n_f16(s5, 0.5));
|
||||
float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(s2, 2.25), vmulq_n_f16(s4, 3.25));
|
||||
float16x8_t m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s3, 1.625)), s6);
|
||||
float16x8_t m2 = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s3, 1.625)), s6);
|
||||
tmp1 = vaddq_f16(vmulq_n_f16(s1, 0.5625), s5);
|
||||
tmp2 = vsubq_f16(vmulq_n_f16(s2, 0.5625), vmulq_n_f16(s4, 2.5));
|
||||
float16x8_t m3 = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s3, 2.5)), s6);
|
||||
float16x8_t m4 = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s3, 2.5)), s6);
|
||||
tmp1 = vaddq_f16(vmulq_n_f16(s1, 0.375), vmulq_n_f16(s5, 1.5));
|
||||
tmp2 = vsubq_f16(vmulq_n_f16(s2, 0.25), vmulq_n_f16(s4, 1.25));
|
||||
float16x8_t m5 = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s3, 1.875)), s6);
|
||||
float16x8_t m6 = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s3, 1.875)), s6);
|
||||
float16x8_t m7 =
|
||||
vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s1, -0.5625), vmulq_n_f16(s3, 3.0625)), vmulq_n_f16(s5, 3.5)), s7);
|
||||
|
||||
vst1q_f16(dst_ptr + 0 * dst_step, m0);
|
||||
vst1q_f16(dst_ptr + 1 * dst_step, m1);
|
||||
vst1q_f16(dst_ptr + 2 * dst_step, m2);
|
||||
vst1q_f16(dst_ptr + 3 * dst_step, m3);
|
||||
vst1q_f16(dst_ptr + 4 * dst_step, m4);
|
||||
vst1q_f16(dst_ptr + 5 * dst_step, m5);
|
||||
vst1q_f16(dst_ptr + 6 * dst_step, m6);
|
||||
vst1q_f16(dst_ptr + 7 * dst_step, m7);
|
||||
}
|
||||
}
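The fractional constants in InputTransform8x8StepFp16 (0.5625 = 9/16, 3.0625 = 49/16, 1.625 = 13/8, and so on) transcribe, row by row, into this $B^\top$ for the 8x8 input unit:

$$ B^\top = \begin{pmatrix}
0.5625 & 0 & -3.0625 & 0 & 3.5 & 0 & -1 & 0 \\
0 & 1.125 & 2.25 & -1.625 & -3.25 & 0.5 & 1 & 0 \\
0 & -1.125 & 2.25 & 1.625 & -3.25 & -0.5 & 1 & 0 \\
0 & 0.5625 & 0.5625 & -2.5 & -2.5 & 1 & 1 & 0 \\
0 & -0.5625 & 0.5625 & 2.5 & -2.5 & -1 & 1 & 0 \\
0 & 0.375 & 0.25 & -1.875 & -1.25 & 1.5 & 1 & 0 \\
0 & -0.375 & 0.25 & 1.875 & -1.25 & -1.5 & 1 & 0 \\
0 & -0.5625 & 0 & 3.0625 & 0 & -3.5 & 0 & 1
\end{pmatrix} $$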
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void InputTransform8x8Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile,
|
||||
int src_point_stride) {
|
||||
LOAD_LINE_DATA_FP16(0);
|
||||
LOAD_LINE_DATA_FP16(1);
|
||||
LOAD_LINE_DATA_FP16(2);
|
||||
LOAD_LINE_DATA_FP16(3);
|
||||
LOAD_LINE_DATA_FP16(4);
|
||||
LOAD_LINE_DATA_FP16(5);
|
||||
LOAD_LINE_DATA_FP16(6);
|
||||
LOAD_LINE_DATA_FP16(7);
|
||||
|
||||
float16x8_t m0 =
|
||||
vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s00, 0.5625), vmulq_n_f16(s20, 3.0625)), vmulq_n_f16(s40, 3.5)), s60);
|
||||
float16x8_t m1 =
|
||||
vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s01, 0.5625), vmulq_n_f16(s21, 3.0625)), vmulq_n_f16(s41, 3.5)), s61);
|
||||
vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
float16x8_t tmp10 = vaddq_f16(vmulq_n_f16(s10, 1.125), vmulq_n_f16(s50, 0.5));
|
||||
float16x8_t tmp11 = vaddq_f16(vmulq_n_f16(s11, 1.125), vmulq_n_f16(s51, 0.5));
|
||||
float16x8_t tmp20 = vsubq_f16(vmulq_n_f16(s20, 2.25), vmulq_n_f16(s40, 3.25));
|
||||
float16x8_t tmp21 = vsubq_f16(vmulq_n_f16(s21, 2.25), vmulq_n_f16(s41, 3.25));
|
||||
m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 1.625)), s60);
|
||||
m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 1.625)), s61);
|
||||
vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 1.625)), s60);
|
||||
m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 1.625)), s61);
|
||||
vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
tmp10 = vaddq_f16(vmulq_n_f16(s10, 0.5625), s50);
|
||||
tmp11 = vaddq_f16(vmulq_n_f16(s11, 0.5625), s51);
|
||||
tmp20 = vsubq_f16(vmulq_n_f16(s20, 0.5625), vmulq_n_f16(s40, 2.5));
|
||||
tmp21 = vsubq_f16(vmulq_n_f16(s21, 0.5625), vmulq_n_f16(s41, 2.5));
|
||||
m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 2.5)), s60);
|
||||
m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 2.5)), s61);
|
||||
vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 2.5)), s60);
|
||||
m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 2.5)), s61);
|
||||
vst1q_f16(dst_ptr + 4 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 4 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
tmp10 = vaddq_f16(vmulq_n_f16(s10, 0.375), vmulq_n_f16(s50, 1.5));
|
||||
tmp11 = vaddq_f16(vmulq_n_f16(s11, 0.375), vmulq_n_f16(s51, 1.5));
|
||||
tmp20 = vsubq_f16(vmulq_n_f16(s20, 0.25), vmulq_n_f16(s40, 1.25));
|
||||
tmp21 = vsubq_f16(vmulq_n_f16(s21, 0.25), vmulq_n_f16(s41, 1.25));
|
||||
m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 1.875)), s60);
|
||||
m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 1.875)), s61);
|
||||
vst1q_f16(dst_ptr + 5 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 5 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 1.875)), s60);
|
||||
m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 1.875)), s61);
|
||||
vst1q_f16(dst_ptr + 6 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 6 * dst_step + 1 * pack_tile, m1);
|
||||
|
||||
m0 = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s10, -0.5625), vmulq_n_f16(s30, 3.0625)), vmulq_n_f16(s50, 3.5)), s70);
|
||||
m1 = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s11, -0.5625), vmulq_n_f16(s31, 3.0625)), vmulq_n_f16(s51, 3.5)), s71);
|
||||
vst1q_f16(dst_ptr + 7 * dst_step + 0 * pack_tile, m0);
|
||||
vst1q_f16(dst_ptr + 7 * dst_step + 1 * pack_tile, m1);
|
||||
}
|
||||
|
||||
void InputTransform8x8Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) {
|
||||
int block_tile = 16;
|
||||
int pack_tile = src_step;
|
||||
int src_point_stride = block_tile * pack_tile;
|
||||
for (int l = 0; l < 8; ++l) {
|
||||
float16_t *src_ptr = src_data + l * C8NUM * block_tile;
|
||||
TRANSPOSE_16x8;
|
||||
}
|
||||
|
||||
for (int c = 0; c < real_c; ++c) {
|
||||
float16_t *src_ptr = src_data + c * block_tile;
|
||||
float16_t *dst_ptr = dst_data + c * block_tile;
|
||||
InputTransform8x8Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type) {
|
||||
if (input_unit == 4 && output_unit < 4) {
|
||||
if (act_type == ActType_Relu) {
|
||||
|
|
|
@@ -29,9 +29,22 @@ extern "C" {
|
|||
typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
|
||||
int real_c);
|
||||
|
||||
typedef void (*InputTransStepFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
|
||||
int dst_row_step);
|
||||
|
||||
typedef void (*InputTransPackFp16Func)(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
|
||||
int real_c);
|
||||
|
||||
typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
|
||||
int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
|
||||
typedef struct TransFp16FuncList {
|
||||
InputTransFp16Func in_func_;
|
||||
InputTransStepFp16Func in_step_func_;
|
||||
InputTransPackFp16Func in_pack_func_;
|
||||
OutputTransFp16Func out_func_;
|
||||
} TransFp16FuncList;
|
||||
|
||||
#define Load16DataFp16 \
|
||||
src[0] = vld1q_f16(src_data + 0 * src_step); \
|
||||
src[1] = vld1q_f16(src_data + 1 * src_step); \
|
||||
|
@@ -276,14 +289,77 @@ typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_da
|
|||
src[62] = vld1_f16(src_data + 62 * src_step); \
|
||||
src[63] = vld1_f16(src_data + 63 * src_step);
|
||||
|
||||
#define LOAD_LINE_DATA_FP16(line) \
|
||||
float16x8_t s##line##0 = vld1q_f16(src_ptr + line * src_point_stride + 0 * pack_tile); \
|
||||
float16x8_t s##line##1 = vld1q_f16(src_ptr + line * src_point_stride + 1 * pack_tile);
|
||||
|
||||
#define TRANSPOSE_16x8 \
|
||||
float16x8_t s0 = vld1q_f16(src_ptr + 0 * pack_tile); \
|
||||
float16x8_t s2 = vld1q_f16(src_ptr + 1 * pack_tile); \
|
||||
float16x8_t s4 = vld1q_f16(src_ptr + 2 * pack_tile); \
|
||||
float16x8_t s6 = vld1q_f16(src_ptr + 3 * pack_tile); \
|
||||
float16x8_t s8 = vld1q_f16(src_ptr + 4 * pack_tile); \
|
||||
float16x8_t s10 = vld1q_f16(src_ptr + 5 * pack_tile); \
|
||||
float16x8_t s12 = vld1q_f16(src_ptr + 6 * pack_tile); \
|
||||
float16x8_t s14 = vld1q_f16(src_ptr + 7 * pack_tile); \
|
||||
float16x8_t s1 = vld1q_f16(src_ptr + 8 * pack_tile); \
|
||||
float16x8_t s3 = vld1q_f16(src_ptr + 9 * pack_tile); \
|
||||
float16x8_t s5 = vld1q_f16(src_ptr + 10 * pack_tile); \
|
||||
float16x8_t s7 = vld1q_f16(src_ptr + 11 * pack_tile); \
|
||||
float16x8_t s9 = vld1q_f16(src_ptr + 12 * pack_tile); \
|
||||
float16x8_t s11 = vld1q_f16(src_ptr + 13 * pack_tile); \
|
||||
float16x8_t s13 = vld1q_f16(src_ptr + 14 * pack_tile); \
|
||||
float16x8_t s15 = vld1q_f16(src_ptr + 15 * pack_tile); \
|
||||
transpose8(&s0, &s2, &s4, &s6, &s8, &s10, &s12, &s14); \
|
||||
transpose8(&s1, &s3, &s5, &s7, &s9, &s11, &s13, &s15); \
|
||||
vst1q_f16(src_ptr + 0 * pack_tile, s0); \
|
||||
vst1q_f16(src_ptr + 1 * pack_tile, s1); \
|
||||
vst1q_f16(src_ptr + 2 * pack_tile, s2); \
|
||||
vst1q_f16(src_ptr + 3 * pack_tile, s3); \
|
||||
vst1q_f16(src_ptr + 4 * pack_tile, s4); \
|
||||
vst1q_f16(src_ptr + 5 * pack_tile, s5); \
|
||||
vst1q_f16(src_ptr + 6 * pack_tile, s6); \
|
||||
vst1q_f16(src_ptr + 7 * pack_tile, s7); \
|
||||
vst1q_f16(src_ptr + 8 * pack_tile, s8); \
|
||||
vst1q_f16(src_ptr + 9 * pack_tile, s9); \
|
||||
vst1q_f16(src_ptr + 10 * pack_tile, s10); \
|
||||
vst1q_f16(src_ptr + 11 * pack_tile, s11); \
|
||||
vst1q_f16(src_ptr + 12 * pack_tile, s12); \
|
||||
vst1q_f16(src_ptr + 13 * pack_tile, s13); \
|
||||
vst1q_f16(src_ptr + 14 * pack_tile, s14); \
|
||||
vst1q_f16(src_ptr + 15 * pack_tile, s15);
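As I read the two macros above: LOAD_LINE_DATA_FP16 pulls one transformed line of a 16-tile x 8-channel block, and TRANSPOSE_16x8 reinterleaves such a block in place from tile-major to channel-major order so that the per-channel Pack16 routines can stream it. A scalar sketch of that reshuffle (float elements for clarity, illustration only, not part of the patch):

/* Scalar sketch of TRANSPOSE_16x8: reorder one point of a packed block from
 * [16 tiles][8 channels] to [8 channels][16 tiles], in place. */
static void Transpose16x8Ref(float buf[16 * 8]) {
  float tmp[8 * 16];
  for (int t = 0; t < 16; ++t) {
    for (int c = 0; c < 8; ++c) {
      tmp[c * 16 + t] = buf[t * 8 + c];
    }
  }
  for (int i = 0; i < 16 * 8; ++i) {
    buf[i] = tmp[i];
  }
}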
|
||||
|
||||
InputTransFp16Func GetInputTransFp16Func(int input_unit);
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
InputTransStepFp16Func GetInputTransStepFp16Func(int input_unit);
|
||||
|
||||
InputTransPackFp16Func GetInputTransPackFp16Func(int input_unit);
|
||||
#endif
|
||||
|
||||
void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
|
||||
|
||||
void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
|
||||
|
||||
void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
|
||||
|
||||
void InputTransform4x4StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
|
||||
int dst_row_step);
|
||||
|
||||
void InputTransform6x6StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
|
||||
int dst_row_step);
|
||||
|
||||
void InputTransform8x8StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
|
||||
int dst_row_step);
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void InputTransform4x4Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
|
||||
|
||||
void InputTransform6x6Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
|
||||
|
||||
void InputTransform8x8Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
|
||||
#endif
|
||||
|
||||
OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type);
|
||||
|
||||
#define Store4DataFp16 \
|
||||
|
|
|
@@ -23,11 +23,12 @@
|
|||
// fp32 conv winograd
|
||||
void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data,
|
||||
TmpBufferAddress *buffer_list, int task_id, const ConvParameter *conv_param,
|
||||
InputTransFunc in_func, OutputTransFunc out_func) {
|
||||
TransFuncList trans_func) {
|
||||
if (conv_param->output_unit_ == 0) {
|
||||
return;
|
||||
}
|
||||
int in_channel = conv_param->input_channel_;
|
||||
int input_unit = conv_param->input_unit_;
|
||||
int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_);
|
||||
int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_);
|
||||
int output_count = out_w_block * out_h_block;
|
||||
|
@@ -35,26 +36,19 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
|
|||
int output_tile_count = UP_DIV(output_count, tile_num);
|
||||
#ifdef ENABLE_AVX
|
||||
const int col_tile = C16NUM;
|
||||
const int tmp_data_tile = C8NUM;
|
||||
const int channel_pack_tile = C8NUM;
|
||||
#else
|
||||
const int col_tile = C8NUM;
|
||||
const int tmp_data_tile = C4NUM;
|
||||
const int channel_pack_tile = C4NUM;
|
||||
#endif
|
||||
int oc_tile = UP_DIV(conv_param->output_channel_, col_tile);
|
||||
int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
|
||||
int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_;
|
||||
if (input_unit_square < conv_param->input_unit_) {
|
||||
return;
|
||||
}
|
||||
int input_unit_square = input_unit * input_unit;
|
||||
|
||||
float *trans_input = buffer_list[0];
|
||||
float *gemm_out = buffer_list[1];
|
||||
float *tmp_data = buffer_list[2];
|
||||
float *col_buffer = buffer_list[3];
|
||||
int trans_input_offset = tile_num * input_unit_square * in_channel;
|
||||
int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
|
||||
int tmp_data_offset = input_unit_square * tmp_data_tile;
|
||||
int col_buffer_offset = tile_num * in_channel;
|
||||
float *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel;
|
||||
float *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM;
|
||||
float *tmp_data = buffer_list[2] + task_id * input_unit_square * channel_pack_tile;
|
||||
float *col_buffer = buffer_list[3] + task_id * tile_num * in_channel;
|
||||
// step 1 : filter transform (pre-processed offline)
|
||||
// step 2 : input transform (online)
|
||||
for (int b = 0; b < conv_param->input_batch_; b++) {
|
||||
|
@@ -67,37 +61,75 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
|
|||
if (cal_num <= 0) {
|
||||
return;
|
||||
}
|
||||
WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
|
||||
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
in_func);
|
||||
// step 3 : gemm
|
||||
float *src_ptr = trans_input + task_id * trans_input_offset;
|
||||
float *dst_ptr = gemm_out + task_id * gemm_out_offset;
|
||||
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
|
||||
for (int i = 0; i < input_unit_square; ++i) {
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
// Optimized input transform. Only valid for arm64, where the tile num is 12 and the channel_tile is 4.
// For arm32, the tile_num is 4.
// For x86_sse, the tile_num is 4 and the channel_tile is 4.
// For avx, the tile_num is 6 and the channel_tile is 8.
// N = input_unit, M = tile_num
// The functions (InputTransformNxNStep, InputTransform4x4PackM) would need to be rewritten.
|
||||
bool fused_pack =
|
||||
(cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL);
|
||||
if (fused_pack) {
|
||||
float *opt_trans_input =
|
||||
buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, channel_pack_tile);
|
||||
WinogradInputTransformOptStep(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num, out_tile_index,
|
||||
out_w_block, conv_param, trans_func.in_step_func_);
|
||||
|
||||
for (int w_index = 0; w_index < input_unit; w_index++) {
|
||||
float *src_w = opt_trans_input + w_index * input_unit * tile_num * channel_pack_tile;
|
||||
for (int c = 0; c < UP_DIV(in_channel, channel_pack_tile); c++) {
|
||||
int real_c = in_channel - c * channel_pack_tile;
|
||||
real_c = real_c > channel_pack_tile ? channel_pack_tile : real_c;
|
||||
float *src_c = src_w + c * input_unit_square * tile_num * channel_pack_tile;
|
||||
float *dst_c = trans_input + c * tile_num * channel_pack_tile;
|
||||
trans_func.in_pack_func_(src_c, dst_c, channel_pack_tile, in_channel * tile_num, real_c);
|
||||
}
|
||||
|
||||
for (int h_index = 0; h_index < input_unit; h_index++) {
|
||||
const float *gemm_input = trans_input + h_index * tile_num * in_channel;
|
||||
int point_index = h_index * input_unit + w_index;
|
||||
const float *gemm_weight = trans_weight + point_index * in_channel * oc_tile * col_tile;
|
||||
MatMulOpt(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num,
|
||||
oc8 * C8NUM, input_unit_square, OutType_TileC8);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#endif
|
||||
WinogradInputTransform(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index,
|
||||
out_w_block, conv_param, trans_func.in_func_);
|
||||
// step 3 : gemm
|
||||
float *src_ptr = trans_input;
|
||||
float *dst_ptr = gemm_out;
|
||||
float *tmp_col_ptr = col_buffer;
|
||||
for (int i = 0; i < input_unit_square; ++i) {
|
||||
#ifdef ENABLE_AVX
|
||||
RowMajor2Col6Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
|
||||
RowMajor2Col6Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
|
||||
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
|
||||
RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
|
||||
#else
|
||||
RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
|
||||
#endif
|
||||
MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc_tile * col_tile, dst_ptr + i * C8NUM, NULL, 0,
|
||||
in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2);
|
||||
MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc_tile * col_tile, dst_ptr + i * C8NUM, NULL, 0,
|
||||
in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2);
|
||||
}
|
||||
#ifdef ENABLE_ARM64
|
||||
}
|
||||
#endif
|
||||
|
||||
// step 4 : output transform
|
||||
float *output_ptr = output_data + out_batch_offset;
|
||||
if (conv_param->out_format_ != NNACL_NC4HW4) { // nc4hw4
|
||||
WinogradOutputNHWCTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
out_func);
|
||||
WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
trans_func.out_func_);
|
||||
} else {
|
||||
#if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
|
||||
WinogradOutputNC4HW4Transform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
out_func);
|
||||
WinogradOutputNC4HW4Transform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
trans_func.out_func_);
|
||||
#else
|
||||
WinogradOutputNHWCTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
out_func);
|
||||
WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
trans_func.out_func_);
|
||||
#endif
|
||||
}
|
||||
}
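One practical consequence of the new fused path (fp16 and fp32 alike): it dereferences buffer_list[4], so callers have to provide a fifth scratch buffer. A hypothetical helper (not part of the patch) that reads the per-thread element count straight off the pointer arithmetic above:

#include <stddef.h>

/* Hypothetical helper, illustration only: per-thread element count of the
 * opt_trans_input area addressed as buffer_list[4] above. */
static size_t OptTransInputElemsPerThread(int tile_num, int input_unit, int in_channel,
                                          int channel_pack_tile) {
  int input_unit_square = input_unit * input_unit;
  int packed_channel = ((in_channel + channel_pack_tile - 1) / channel_pack_tile) * channel_pack_tile; /* UP_ROUND */
  return (size_t)tile_num * input_unit_square * packed_channel;
}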
|
||||
|
|
|
@@ -36,7 +36,7 @@ extern "C" {
|
|||
// fp32 convolution winograd
|
||||
void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data,
|
||||
TmpBufferAddress *buffer_list, int task_id, const ConvParameter *conv_param,
|
||||
InputTransFunc in_func, OutputTransFunc out_func);
|
||||
TransFuncList trans_func);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@@ -17,6 +17,59 @@
|
|||
#include "nnacl/fp32/winograd_transform.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
void PrepareTransInput(const float *src_data, float *dst_data, int interval_x_s, int interval_x_e, int interval_y_s,
|
||||
int interval_y_e, int real_c, const ConvParameter *conv_param) {
|
||||
int input_unit = conv_param->input_unit_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
int input_w = conv_param->input_w_;
|
||||
#ifdef ENABLE_AVX
|
||||
int channel_tile = C8NUM;
|
||||
#else
|
||||
int channel_tile = C4NUM;
|
||||
#endif
|
||||
// clear tmp buffer
|
||||
if (interval_x_e - interval_x_s != input_unit || interval_y_e - interval_y_s != input_unit) {
|
||||
memset(dst_data, 0, input_unit * input_unit * channel_tile * (int)(sizeof(float)));
|
||||
}
|
||||
|
||||
// get real input block with padding
|
||||
if (real_c == channel_tile) {
|
||||
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
|
||||
int src_y_offset = (interval * input_w + interval_x_s) * in_channel;
|
||||
int dst_y_offset = interval * input_unit * channel_tile + interval_x_s * channel_tile;
|
||||
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
|
||||
int src_x_offset = src_y_offset + j * in_channel;
|
||||
int dst_x_offset = dst_y_offset + j * channel_tile;
|
||||
const float *src_addr = src_data + src_x_offset;
|
||||
float *dst_addr = dst_data + dst_x_offset;
|
||||
#ifdef ENABLE_AVX
|
||||
MS_ST256_F32(dst_addr, MS_LD256_F32(src_addr));
|
||||
#elif defined(ENABLE_ARM) || defined(ENABLE_SSE)
|
||||
MS_STQ_F32(dst_addr, MS_LDQ_F32(src_addr));
|
||||
#else
|
||||
for (int k = 0; k < channel_tile; k++) {
|
||||
dst_addr[k] = src_addr[k];
|
||||
}
|
||||
#endif
|
||||
} // interval x loop
|
||||
} // interval y loop
|
||||
} else {
|
||||
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
|
||||
int src_y_offset = (interval * input_w + interval_x_s) * in_channel;
|
||||
int dst_y_offset = interval * input_unit * channel_tile + interval_x_s * channel_tile;
|
||||
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
|
||||
int src_x_offset = src_y_offset + j * in_channel;
|
||||
int dst_x_offset = dst_y_offset + j * channel_tile;
|
||||
const float *src_addr = src_data + src_x_offset;
|
||||
float *dst_addr = dst_data + dst_x_offset;
|
||||
for (int k = 0; k < real_c; k++) {
|
||||
dst_addr[k] = src_addr[k];
|
||||
}
|
||||
} // interval x loop
|
||||
} // interval y loop
|
||||
}
|
||||
}
|
||||
|
||||
// fp32 conv winograd
|
||||
void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num,
|
||||
int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
|
||||
|
@@ -25,11 +78,11 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
|
|||
int output_unit = conv_param->output_unit_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
#ifdef ENABLE_AVX
|
||||
int tile = C8NUM;
|
||||
int channel_tile = C8NUM;
|
||||
#else
|
||||
int tile = C4NUM;
|
||||
int channel_tile = C4NUM;
|
||||
#endif
|
||||
int ic4 = UP_DIV(in_channel, tile);
|
||||
int ic4 = UP_DIV(in_channel, channel_tile);
|
||||
int pad_h = conv_param->pad_u_;
|
||||
int pad_w = conv_param->pad_l_;
|
||||
int input_h = conv_param->input_h_;
|
||||
|
@@ -49,54 +102,61 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
|
|||
int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s);
|
||||
int dst_plane_offset = c * in_channel;
|
||||
for (int ic = 0; ic < ic4; ic++) {
|
||||
// clear tmp buffer
|
||||
memset(tmp_data, 0, input_unit * input_unit * tile * (int)(sizeof(float)));
|
||||
int real_c = in_channel - ic * channel_tile;
|
||||
real_c = real_c > channel_tile ? channel_tile : real_c;
|
||||
const float *src_data = input_data + src_plane_offset + ic * channel_tile;
|
||||
PrepareTransInput(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, conv_param);
|
||||
|
||||
int real_c = in_channel - ic * tile;
|
||||
real_c = real_c > tile ? tile : real_c;
|
||||
int src_ic4_offset = src_plane_offset + ic * tile;
|
||||
// get real input block with padding
|
||||
if (real_c == tile) {
|
||||
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
|
||||
int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * in_channel;
|
||||
int dst_y_offset = interval * input_unit * tile + interval_x_s * tile;
|
||||
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
|
||||
int src_x_offset = src_y_offset + j * in_channel;
|
||||
int dst_x_offset = dst_y_offset + j * tile;
|
||||
float *src_addr = (float *)(input_data) + src_x_offset;
|
||||
float *dst_addr = tmp_data + dst_x_offset;
|
||||
#ifdef ENABLE_AVX
|
||||
MS_ST256_F32(dst_addr, MS_LD256_F32(src_addr));
|
||||
#elif defined(ENABLE_ARM) || defined(ENABLE_SSE)
|
||||
MS_STQ_F32(dst_addr, MS_LDQ_F32(src_addr));
|
||||
#else
|
||||
for (int k = 0; k < tile; k++) {
|
||||
dst_addr[k] = src_addr[k];
|
||||
}
|
||||
#endif
|
||||
} // interval x loop
|
||||
} // interval y loop
|
||||
} else {
|
||||
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
|
||||
int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * in_channel;
|
||||
int dst_y_offset = interval * input_unit * tile + interval_x_s * tile;
|
||||
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
|
||||
int src_x_offset = src_y_offset + j * in_channel;
|
||||
int dst_x_offset = dst_y_offset + j * tile;
|
||||
float *src_addr = (float *)(input_data) + src_x_offset;
|
||||
float *dst_addr = tmp_data + dst_x_offset;
|
||||
for (int k = 0; k < real_c; k++) {
|
||||
dst_addr[k] = src_addr[k];
|
||||
}
|
||||
} // interval x loop
|
||||
} // interval y loop
|
||||
}
|
||||
// input transform
|
||||
const int tile_num = C12NUM;
|
||||
int dst_ic4_offset = dst_plane_offset + ic * tile;
|
||||
int dst_ic4_offset = dst_plane_offset + ic * channel_tile;
|
||||
int dst_step = tile_num * in_channel;
|
||||
float *trans_input_ptr = trans_input + dst_ic4_offset;
|
||||
func(tmp_data, trans_input_ptr, tile, dst_step, real_c);
|
||||
func(tmp_data, trans_input_ptr, channel_tile, dst_step, real_c);
|
||||
}
|
||||
out_tile_index++;
|
||||
} // cal_tile_num loop
|
||||
}
|
||||
|
||||
// Only supported on arm64
|
||||
void WinogradInputTransformOptStep(const float *input_data, float *trans_input, float *tmp_data, int cal_num,
|
||||
int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
|
||||
InputTransStepFunc func) {
|
||||
int input_unit = conv_param->input_unit_;
|
||||
int output_unit = conv_param->output_unit_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
int channel_tile = C4NUM;
|
||||
int ic4 = UP_DIV(in_channel, channel_tile);
|
||||
int pad_h = conv_param->pad_u_;
|
||||
int pad_w = conv_param->pad_l_;
|
||||
int input_h = conv_param->input_h_;
|
||||
int input_w = conv_param->input_w_;
|
||||
NNACL_CHECK_ZERO_RETURN(out_w_block_num);
|
||||
|
||||
for (int c = 0; c < cal_num; c++) { // actual tiled number
|
||||
int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w;
|
||||
int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h;
|
||||
int interval_x_s = src_x_s > 0 ? 0 : -src_x_s;
|
||||
int interval_y_s = src_y_s > 0 ? 0 : -src_y_s;
|
||||
int src_x_e = src_x_s + input_unit;
|
||||
int src_y_e = src_y_s + input_unit;
|
||||
int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s);
|
||||
int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s);
|
||||
|
||||
int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s);
|
||||
int dst_plane_offset = c * channel_tile;
|
||||
for (int ic = 0; ic < ic4; ic++) {
|
||||
int real_c = in_channel - ic * channel_tile;
|
||||
real_c = real_c > channel_tile ? channel_tile : real_c;
|
||||
const float *src_data = input_data + src_plane_offset + ic * channel_tile;
|
||||
PrepareTransInput(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, conv_param);
|
||||
|
||||
// input transform
|
||||
const int block_tile = C12NUM;
|
||||
int dst_ic8_offset = dst_plane_offset + ic * block_tile * input_unit * input_unit * channel_tile;
|
||||
size_t dst_step = input_unit * block_tile * channel_tile;
|
||||
float *trans_input_ptr = trans_input + dst_ic8_offset;
|
||||
func(tmp_data, trans_input_ptr, channel_tile, dst_step, block_tile * channel_tile);
|
||||
}
|
||||
out_tile_index++;
|
||||
} // cal_tile_num loop
|
||||
|
|
|
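The driver that ties the ARM64 two-stage path together lives in ConvWinogardFp32 and is outside this hunk. A hedged sketch of how the step and pack functions might be combined, under the assumption that trans_input here corresponds to the opt_input_trans_ buffer allocated by the kernels further below; the stage-2 offsets and strides are placeholders, not the actual layout:

void WinogradInputTransformTwoStageSketch(const float *input, float *row_buf, float *trans_input, float *tmp_data,
                                          int cal_num, int out_tile_index, int out_w_block,
                                          const ConvParameter *conv_param, InputTransStepFunc step_func,
                                          InputTransPackFunc pack_func) {
  int input_unit = conv_param->input_unit_;
  int in_channel = conv_param->input_channel_;
  int ic4 = UP_DIV(in_channel, C4NUM);
  // Stage 1: row-direction transform of every C4 channel block into row_buf.
  WinogradInputTransformOptStep(input, row_buf, tmp_data, cal_num, out_tile_index, out_w_block, conv_param, step_func);
  // Stage 2: column-direction transform plus 12-tile packing into trans_input.
  for (int ic = 0; ic < ic4; ic++) {
    int real_c = in_channel - ic * C4NUM;
    real_c = real_c > C4NUM ? C4NUM : real_c;
    // Placeholder addressing: the real driver derives these from the buffer layout
    // it chose when allocating row_buf and trans_input.
    float *src = row_buf + ic * C12NUM * input_unit * input_unit * C4NUM;
    float *dst = trans_input + ic * C4NUM;
    pack_func(src, dst, C4NUM, C12NUM * in_channel, real_c);
  }
}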
@ -32,6 +32,10 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
                             int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
                             InputTransFunc func);

void WinogradInputTransformOptStep(const float *input_data, float *trans_input, float *tmp_data, int cal_num,
                                   int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
                                   InputTransStepFunc func);

void WinogradOutputNHWCTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num,
                                 int out_tile_index, int output_unit_num, const ConvParameter *conv_param,
                                 OutputTransFunc func);
@ -20,6 +20,19 @@
#include "nnacl/base/conv_common_base.h"
#include "nnacl/errorcode.h"

#ifdef ENABLE_ARM64
void transpose4(MS_FLOAT32X4 *s0, MS_FLOAT32X4 *s1, MS_FLOAT32X4 *s2, MS_FLOAT32X4 *s3) {
  float64x2_t m0 = (float64x2_t)(vtrn1q_f32(*s0, *s1));
  float64x2_t m1 = (float64x2_t)(vtrn2q_f32(*s0, *s1));
  float64x2_t m2 = (float64x2_t)(vtrn1q_f32(*s2, *s3));
  float64x2_t m3 = (float64x2_t)(vtrn2q_f32(*s2, *s3));
  *s0 = (float32x4_t)(vtrn1q_f64(m0, m2));
  *s2 = (float32x4_t)(vtrn2q_f64(m0, m2));
  *s1 = (float32x4_t)(vtrn1q_f64(m1, m3));
  *s3 = (float32x4_t)(vtrn2q_f64(m1, m3));
}
#endif

#ifdef ENABLE_AVX
static InputTransFunc InputTransFuncList[] = {
  NULL, NULL, NULL, NULL, InputTransform4x4AvxUnit, NULL, InputTransform6x6AvxUnit, NULL, InputTransform8x8AvxUnit};

@ -55,6 +68,18 @@ static OutputTransFunc OutputTransFuncList[] = {

InputTransFunc GetInputTransFunc(int input_unit) { return InputTransFuncList[input_unit]; }

#ifdef ENABLE_ARM64
static InputTransStepFunc InputTransStepFuncList[] = {
  NULL, NULL, NULL, NULL, InputTransform4x4Step, NULL, InputTransform6x6Step, NULL, InputTransform8x8Step};

static InputTransPackFunc InputTransPackFuncList[] = {
  NULL, NULL, NULL, NULL, InputTransform4x4Pack12, NULL, InputTransform6x6Pack12, NULL, InputTransform8x8Pack12};

InputTransStepFunc GetInputTransStepFunc(int input_unit) { return InputTransStepFuncList[input_unit]; }

InputTransPackFunc GetInputTransPackFunc(int input_unit) { return InputTransPackFuncList[input_unit]; }
#endif

void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
  if (real_c == 4) {

@ -138,6 +163,136 @@ void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step,
#endif
}
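For reference, the getters above are normally combined the same way the CPU kernels do in ConfigInputOutput() later in this commit. An illustrative helper (not part of the diff; input_unit, output_unit and act_type come from the concrete kernel's ConvParameter):

static int SetupWinogradTransFuncs(int input_unit, int output_unit, ActType act_type, TransFuncList *trans_func) {
  trans_func->in_func_ = GetInputTransFunc(input_unit);
  if (trans_func->in_func_ == NULL) {
    return NNACL_ERR;
  }
#ifdef ENABLE_ARM64
  // The step/pack pair is optional: a NULL entry simply disables the optimized path.
  trans_func->in_step_func_ = GetInputTransStepFunc(input_unit);
  trans_func->in_pack_func_ = GetInputTransPackFunc(input_unit);
#endif
  trans_func->out_func_ = GetOutputTransFunc(input_unit, output_unit, act_type);
  if (trans_func->out_func_ == NULL) {
    return NNACL_ERR;
  }
  return NNACL_OK;
}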
|
||||
|
||||
void InputTransform4x4Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) {
|
||||
#ifdef ENABLE_ARM64
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const float *src_ptr = src_data + l * 4 * src_step;
|
||||
float *dst_ptr = dst_data + l * dst_row_step;
|
||||
|
||||
MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step);
|
||||
MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step);
|
||||
MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step);
|
||||
MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step);
|
||||
MS_FLOAT32X4 m0 = MS_SUBQ_F32(s0, s2);
|
||||
MS_FLOAT32X4 m1 = MS_ADDQ_F32(s1, s2);
|
||||
MS_FLOAT32X4 m2 = MS_SUBQ_F32(s2, s1);
|
||||
MS_FLOAT32X4 m3 = MS_SUBQ_F32(s3, s1);
|
||||
|
||||
MS_STQ_F32(dst_ptr + 0 * dst_step, m0);
|
||||
MS_STQ_F32(dst_ptr + 1 * dst_step, m1);
|
||||
MS_STQ_F32(dst_ptr + 2 * dst_step, m2);
|
||||
MS_STQ_F32(dst_ptr + 3 * dst_step, m3);
|
||||
}
|
||||
#else
|
||||
float src[4];
|
||||
float m[4];
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
for (int w = 0; w < 4; ++w) {
|
||||
int tmp_index = l * 4 + w;
|
||||
src[w] = src_data[i + tmp_index * src_step];
|
||||
}
|
||||
m[0] = src[0] - src[2];
|
||||
m[1] = src[1] + src[2];
|
||||
m[2] = src[2] - src[1];
|
||||
m[3] = src[3] - src[1];
|
||||
|
||||
float *dst = dst_data + l * dst_row_step;
|
||||
for (int w = 0; w < 4; ++w) {
|
||||
dst[i + w * dst_step] = m[w];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void InputTransform4x4Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) {
|
||||
LOAD_LINE_DATA(0);
|
||||
LOAD_LINE_DATA(1);
|
||||
LOAD_LINE_DATA(2);
|
||||
LOAD_LINE_DATA(3);
|
||||
|
||||
MS_FLOAT32X4 m0 = MS_SUBQ_F32(s00, s20);
|
||||
MS_FLOAT32X4 m1 = MS_SUBQ_F32(s01, s21);
|
||||
MS_FLOAT32X4 m2 = MS_SUBQ_F32(s02, s22);
|
||||
MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
m0 = MS_ADDQ_F32(s10, s20);
|
||||
m1 = MS_ADDQ_F32(s11, s21);
|
||||
m2 = MS_ADDQ_F32(s12, s22);
|
||||
MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
m0 = MS_SUBQ_F32(s20, s10);
|
||||
m1 = MS_SUBQ_F32(s21, s11);
|
||||
m2 = MS_SUBQ_F32(s22, s12);
|
||||
MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
m0 = MS_SUBQ_F32(s30, s10);
|
||||
m1 = MS_SUBQ_F32(s31, s11);
|
||||
m2 = MS_SUBQ_F32(s32, s12);
|
||||
MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2);
|
||||
}
|
||||
#endif
|
||||
|
||||
void InputTransform4x4Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
|
||||
int block_tile = 12;
|
||||
int pack_tile = src_step;
|
||||
int src_point_stride = block_tile * pack_tile;
|
||||
#ifdef ENABLE_ARM64
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
float *src_ptr = src_data + l * C4NUM * block_tile;
|
||||
TRANSPOSE_12x4;
|
||||
}
|
||||
|
||||
for (int c = 0; c < real_c; ++c) {
|
||||
float *src_ptr = src_data + c * block_tile;
|
||||
float *dst_ptr = dst_data + c * block_tile;
|
||||
InputTransform4x4Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride);
|
||||
}
|
||||
#else
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
float *src = src_data + l * pack_tile * block_tile;
|
||||
// 12 * 4 -> 4 * 12
|
||||
float tmp_mat[pack_tile][block_tile];
|
||||
for (int i = 0; i < block_tile; ++i) {
|
||||
for (int j = 0; j < pack_tile; ++j) {
|
||||
tmp_mat[j][i] = src[i * pack_tile + j];
|
||||
}
|
||||
}
|
||||
memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float));
|
||||
}
|
||||
|
||||
float src[4];
|
||||
float m[4];
|
||||
for (int c = 0; c < real_c; ++c) {
|
||||
for (int i = 0; i < block_tile; ++i) {
|
||||
int tmp_index = c * block_tile + i;
|
||||
for (int w = 0; w < 4; ++w) {
|
||||
src[w] = src_data[tmp_index + w * src_point_stride];
|
||||
}
|
||||
|
||||
m[0] = src[0] - src[2];
|
||||
m[1] = src[1] + src[2];
|
||||
m[2] = src[2] - src[1];
|
||||
m[3] = src[3] - src[1];
|
||||
|
||||
for (int w = 0; w < 4; ++w) {
|
||||
dst_data[tmp_index + w * dst_step] = m[w];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
|
||||
if (real_c == 4) {
|
||||
|
@ -217,6 +372,169 @@ void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step,
|
|||
#endif
|
||||
}
|
||||
|
||||
void InputTransform6x6Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) {
|
||||
#ifdef ENABLE_ARM64
|
||||
for (int l = 0; l < 6; ++l) {
|
||||
const float *src_ptr = src_data + l * 6 * src_step;
|
||||
float *dst_ptr = dst_data + l * dst_row_step;
|
||||
|
||||
MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step);
|
||||
MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step);
|
||||
MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step);
|
||||
MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step);
|
||||
MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 4 * src_step);
|
||||
MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 5 * src_step);
|
||||
|
||||
MS_FLOAT32X4 tmp1 = MS_SUBQ_F32(s3, s1);
|
||||
MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(s4, s2);
|
||||
MS_FLOAT32X4 m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s0, 4), MS_MULQ_N_F32(s2, 5)), s4);
|
||||
MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s1, s2), -4), MS_ADDQ_F32(s3, s4));
|
||||
MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s1, s2), 4), MS_SUBQ_F32(s4, s3));
|
||||
MS_FLOAT32X4 m3 = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 2), tmp2);
|
||||
MS_FLOAT32X4 m4 = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, -2), tmp2);
|
||||
MS_FLOAT32X4 m5 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s1, 4), MS_MULQ_N_F32(s3, 5)), s5);
|
||||
|
||||
MS_STQ_F32(dst_ptr + 0 * dst_step, m0);
|
||||
MS_STQ_F32(dst_ptr + 1 * dst_step, m1);
|
||||
MS_STQ_F32(dst_ptr + 2 * dst_step, m2);
|
||||
MS_STQ_F32(dst_ptr + 3 * dst_step, m3);
|
||||
MS_STQ_F32(dst_ptr + 4 * dst_step, m4);
|
||||
MS_STQ_F32(dst_ptr + 5 * dst_step, m5);
|
||||
}
|
||||
#else
|
||||
float src[6];
|
||||
float m[6];
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
for (int l = 0; l < 6; ++l) {
|
||||
for (int w = 0; w < 6; ++w) {
|
||||
int tmp_index = l * 6 + w;
|
||||
src[w] = src_data[i + tmp_index * src_step];
|
||||
}
|
||||
float tmp1 = src[3] - src[1];
|
||||
float tmp2 = src[4] - src[2];
|
||||
m[0] = 4 * src[0] - 5 * src[2] + src[4];
|
||||
m[1] = -4 * (src[1] + src[2]) + (src[3] + src[4]);
|
||||
m[2] = 4 * (src[1] - src[2]) + (src[4] - src[3]);
|
||||
m[3] = 2 * tmp1 + tmp2;
|
||||
m[4] = -2 * tmp1 + tmp2;
|
||||
m[5] = 4 * src[1] - 5 * src[3] + src[5];
|
||||
|
||||
float *dst = dst_data + l * dst_row_step;
|
||||
for (int w = 0; w < 6; ++w) {
|
||||
dst[i + w * dst_step] = m[w];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void InputTransform6x6Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) {
|
||||
LOAD_LINE_DATA(0);
|
||||
LOAD_LINE_DATA(1);
|
||||
LOAD_LINE_DATA(2);
|
||||
LOAD_LINE_DATA(3);
|
||||
LOAD_LINE_DATA(4);
|
||||
LOAD_LINE_DATA(5);
|
||||
|
||||
MS_FLOAT32X4 m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s00, 4), MS_MULQ_N_F32(s20, 5)), s40);
|
||||
MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s01, 4), MS_MULQ_N_F32(s21, 5)), s41);
|
||||
MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s02, 4), MS_MULQ_N_F32(s22, 5)), s42);
|
||||
MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s10, s20), -4), MS_ADDQ_F32(s30, s40));
|
||||
m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s11, s21), -4), MS_ADDQ_F32(s31, s41));
|
||||
m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s12, s22), -4), MS_ADDQ_F32(s32, s42));
|
||||
MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s10, s20), 4), MS_SUBQ_F32(s40, s30));
|
||||
m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s11, s21), 4), MS_SUBQ_F32(s41, s31));
|
||||
m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s12, s22), 4), MS_SUBQ_F32(s42, s32));
|
||||
MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s30, s10), 2), MS_SUBQ_F32(s40, s20));
|
||||
m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s31, s11), 2), MS_SUBQ_F32(s41, s21));
|
||||
m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s32, s12), 2), MS_SUBQ_F32(s42, s22));
|
||||
MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s30, s10), -2), MS_SUBQ_F32(s40, s20));
|
||||
m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s31, s11), -2), MS_SUBQ_F32(s41, s21));
|
||||
m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s32, s12), -2), MS_SUBQ_F32(s42, s22));
|
||||
MS_STQ_F32(dst_ptr + 4 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 4 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 4 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s10, 4), MS_MULQ_N_F32(s30, 5)), s50);
|
||||
m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s11, 4), MS_MULQ_N_F32(s31, 5)), s51);
|
||||
m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s12, 4), MS_MULQ_N_F32(s32, 5)), s52);
|
||||
MS_STQ_F32(dst_ptr + 5 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 5 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 5 * dst_step + 2 * pack_tile, m2);
|
||||
}
|
||||
#endif
|
||||
|
||||
void InputTransform6x6Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
|
||||
int block_tile = 12;
|
||||
int pack_tile = src_step;
|
||||
int src_point_stride = block_tile * pack_tile;
|
||||
#ifdef ENABLE_ARM64
|
||||
for (int l = 0; l < 6; ++l) {
|
||||
float *src_ptr = src_data + l * C4NUM * block_tile;
|
||||
TRANSPOSE_12x4;
|
||||
}
|
||||
|
||||
for (int c = 0; c < real_c; ++c) {
|
||||
float *src_ptr = src_data + c * block_tile;
|
||||
float *dst_ptr = dst_data + c * block_tile;
|
||||
InputTransform6x6Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride);
|
||||
}
|
||||
#else
|
||||
for (int l = 0; l < 6; ++l) {
|
||||
float *src = src_data + l * pack_tile * block_tile;
|
||||
// 12 * 4 -> 4 * 12
|
||||
float tmp_mat[pack_tile][block_tile];
|
||||
for (int i = 0; i < block_tile; ++i) {
|
||||
for (int j = 0; j < pack_tile; ++j) {
|
||||
tmp_mat[j][i] = src[i * pack_tile + j];
|
||||
}
|
||||
}
|
||||
memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float));
|
||||
}
|
||||
|
||||
float src[6];
|
||||
float m[6];
|
||||
for (int c = 0; c < real_c; ++c) {
|
||||
for (int i = 0; i < block_tile; ++i) {
|
||||
int tmp_index = c * block_tile + i;
|
||||
for (int w = 0; w < 6; ++w) {
|
||||
src[w] = src_data[tmp_index + w * src_point_stride];
|
||||
}
|
||||
|
||||
float tmp1 = src[3] - src[1];
|
||||
float tmp2 = src[4] - src[2];
|
||||
m[0] = 4 * src[0] - 5 * src[2] + src[4];
|
||||
m[1] = -4 * (src[1] + src[2]) + (src[3] + src[4]);
|
||||
m[2] = 4 * (src[1] - src[2]) + (src[4] - src[3]);
|
||||
m[3] = 2 * tmp1 + tmp2;
|
||||
m[4] = -2 * tmp1 + tmp2;
|
||||
m[5] = 4 * src[1] - 5 * src[3] + src[5];
|
||||
|
||||
for (int w = 0; w < 6; ++w) {
|
||||
dst_data[tmp_index + w * dst_step] = m[w];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
|
||||
void InputTransform8x8Unit_block4(const float *src_data, float *dst_data, int src_step, int dst_step) {
|
||||
MS_FLOAT32X4 src[64];
|
||||
|
@ -334,6 +652,232 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
|
|||
#endif
|
||||
}
|
||||
|
||||
void InputTransform8x8Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) {
|
||||
#ifdef ENABLE_ARM64
|
||||
for (int l = 0; l < 8; ++l) {
|
||||
const float *src_ptr = src_data + l * 8 * src_step;
|
||||
float *dst_ptr = dst_data + l * dst_row_step;
|
||||
|
||||
MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step);
|
||||
MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step);
|
||||
MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step);
|
||||
MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step);
|
||||
MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 4 * src_step);
|
||||
MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 5 * src_step);
|
||||
MS_FLOAT32X4 s6 = MS_LDQ_F32(src_ptr + 6 * src_step);
|
||||
MS_FLOAT32X4 s7 = MS_LDQ_F32(src_ptr + 7 * src_step);
|
||||
|
||||
MS_FLOAT32X4 m0 = MS_SUBQ_F32(
|
||||
MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s0, 0.5625), MS_MULQ_N_F32(s2, 3.0625)), MS_MULQ_N_F32(s4, 3.5)), s6);
|
||||
MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 1.125), MS_MULQ_N_F32(s5, 0.5));
|
||||
MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 2.25), MS_MULQ_N_F32(s4, 3.25));
|
||||
MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 1.625)), s6);
|
||||
MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 1.625)), s6);
|
||||
tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 0.5625), s5);
|
||||
tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 0.5625), MS_MULQ_N_F32(s4, 2.5));
|
||||
MS_FLOAT32X4 m3 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 2.5)), s6);
|
||||
MS_FLOAT32X4 m4 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 2.5)), s6);
|
||||
tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 0.375), MS_MULQ_N_F32(s5, 1.5));
|
||||
tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 0.25), MS_MULQ_N_F32(s4, 1.25));
|
||||
MS_FLOAT32X4 m5 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 1.875)), s6);
|
||||
MS_FLOAT32X4 m6 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 1.875)), s6);
|
||||
MS_FLOAT32X4 m7 = MS_ADDQ_F32(
|
||||
MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s1, -0.5625), MS_MULQ_N_F32(s3, 3.0625)), MS_MULQ_N_F32(s5, 3.5)), s7);
|
||||
|
||||
MS_STQ_F32(dst_ptr + 0 * dst_step, m0);
|
||||
MS_STQ_F32(dst_ptr + 1 * dst_step, m1);
|
||||
MS_STQ_F32(dst_ptr + 2 * dst_step, m2);
|
||||
MS_STQ_F32(dst_ptr + 3 * dst_step, m3);
|
||||
MS_STQ_F32(dst_ptr + 4 * dst_step, m4);
|
||||
MS_STQ_F32(dst_ptr + 5 * dst_step, m5);
|
||||
MS_STQ_F32(dst_ptr + 6 * dst_step, m6);
|
||||
MS_STQ_F32(dst_ptr + 7 * dst_step, m7);
|
||||
}
|
||||
#else
|
||||
float src[8];
|
||||
float m[8];
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
for (int l = 0; l < 8; ++l) {
|
||||
for (int w = 0; w < 8; ++w) {
|
||||
int tmp_index = l * 8 + w;
|
||||
src[w] = src_data[i + tmp_index * src_step];
|
||||
}
|
||||
m[0] = 0.5625f * src[0] - 3.0625f * src[2] + 3.5f * src[4] - src[6];
|
||||
float tmp1 = 1.125f * src[1] + 0.5f * src[5];
|
||||
float tmp2 = 2.25f * src[2] - 3.25f * src[4];
|
||||
m[1] = tmp1 + tmp2 - 1.625f * src[3] + src[6];
|
||||
m[2] = tmp2 - tmp1 + 1.625f * src[3] + src[6];
|
||||
tmp1 = 0.5625f * src[1] + src[5];
|
||||
tmp2 = 0.5625f * src[2] - 2.5f * src[4];
|
||||
m[3] = tmp1 + tmp2 - 2.5f * src[3] + src[6];
|
||||
m[4] = tmp2 - tmp1 + 2.5f * src[3] + src[6];
|
||||
tmp1 = 0.375f * src[1] + 1.5f * src[5];
|
||||
tmp2 = 0.25f * src[2] - 1.25f * src[4];
|
||||
m[5] = tmp1 + tmp2 - 1.875f * src[3] + src[6];
|
||||
m[6] = tmp2 - tmp1 + 1.875f * src[3] + src[6];
|
||||
m[7] = -0.5625f * src[1] + 3.0625f * src[3] - 3.5f * src[5] + src[7];
|
||||
|
||||
float *dst = dst_data + l * dst_row_step;
|
||||
for (int w = 0; w < 8; ++w) {
|
||||
dst[i + w * dst_step] = m[w];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void InputTransform8x8Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) {
|
||||
LOAD_LINE_DATA(0);
|
||||
LOAD_LINE_DATA(1);
|
||||
LOAD_LINE_DATA(2);
|
||||
LOAD_LINE_DATA(3);
|
||||
LOAD_LINE_DATA(4);
|
||||
LOAD_LINE_DATA(5);
|
||||
LOAD_LINE_DATA(6);
|
||||
LOAD_LINE_DATA(7);
|
||||
|
||||
MS_FLOAT32X4 m0 = MS_SUBQ_F32(
|
||||
MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s00, 0.5625), MS_MULQ_N_F32(s20, 3.0625)), MS_MULQ_N_F32(s40, 3.5)), s60);
|
||||
MS_FLOAT32X4 m1 = MS_SUBQ_F32(
|
||||
MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s01, 0.5625), MS_MULQ_N_F32(s21, 3.0625)), MS_MULQ_N_F32(s41, 3.5)), s61);
|
||||
MS_FLOAT32X4 m2 = MS_SUBQ_F32(
|
||||
MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s02, 0.5625), MS_MULQ_N_F32(s22, 3.0625)), MS_MULQ_N_F32(s42, 3.5)), s62);
|
||||
MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
MS_FLOAT32X4 tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 1.125), MS_MULQ_N_F32(s50, 0.5));
|
||||
MS_FLOAT32X4 tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 1.125), MS_MULQ_N_F32(s51, 0.5));
|
||||
MS_FLOAT32X4 tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 1.125), MS_MULQ_N_F32(s52, 0.5));
|
||||
MS_FLOAT32X4 tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 2.25), MS_MULQ_N_F32(s40, 3.25));
|
||||
MS_FLOAT32X4 tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 2.25), MS_MULQ_N_F32(s41, 3.25));
|
||||
MS_FLOAT32X4 tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 2.25), MS_MULQ_N_F32(s42, 3.25));
|
||||
m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 1.625)), s60);
|
||||
m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 1.625)), s61);
|
||||
m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 1.625)), s62);
|
||||
MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 1.625)), s60);
|
||||
m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 1.625)), s61);
|
||||
m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 1.625)), s62);
|
||||
MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 0.5625), s50);
|
||||
tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 0.5625), s51);
|
||||
tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 0.5625), s52);
|
||||
tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 0.5625), MS_MULQ_N_F32(s40, 2.5));
|
||||
tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 0.5625), MS_MULQ_N_F32(s41, 2.5));
|
||||
tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 0.5625), MS_MULQ_N_F32(s42, 2.5));
|
||||
m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 2.5)), s60);
|
||||
m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 2.5)), s61);
|
||||
m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 2.5)), s62);
|
||||
MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 2.5)), s60);
|
||||
m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 2.5)), s61);
|
||||
m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 2.5)), s62);
|
||||
MS_STQ_F32(dst_ptr + 4 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 4 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 4 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 0.375), MS_MULQ_N_F32(s50, 1.5));
|
||||
tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 0.375), MS_MULQ_N_F32(s51, 1.5));
|
||||
tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 0.375), MS_MULQ_N_F32(s52, 1.5));
|
||||
tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 0.25), MS_MULQ_N_F32(s40, 1.25));
|
||||
tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 0.25), MS_MULQ_N_F32(s41, 1.25));
|
||||
tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 0.25), MS_MULQ_N_F32(s42, 1.25));
|
||||
m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 1.875)), s60);
|
||||
m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 1.875)), s61);
|
||||
m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 1.875)), s62);
|
||||
MS_STQ_F32(dst_ptr + 5 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 5 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 5 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 1.875)), s60);
|
||||
m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 1.875)), s61);
|
||||
m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 1.875)), s62);
|
||||
MS_STQ_F32(dst_ptr + 6 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 6 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 6 * dst_step + 2 * pack_tile, m2);
|
||||
|
||||
m0 = MS_ADDQ_F32(
|
||||
MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s10, -0.5625), MS_MULQ_N_F32(s30, 3.0625)), MS_MULQ_N_F32(s50, 3.5)), s70);
|
||||
m1 = MS_ADDQ_F32(
|
||||
MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s11, -0.5625), MS_MULQ_N_F32(s31, 3.0625)), MS_MULQ_N_F32(s51, 3.5)), s71);
|
||||
m2 = MS_ADDQ_F32(
|
||||
MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s12, -0.5625), MS_MULQ_N_F32(s32, 3.0625)), MS_MULQ_N_F32(s52, 3.5)), s72);
|
||||
MS_STQ_F32(dst_ptr + 7 * dst_step + 0 * pack_tile, m0);
|
||||
MS_STQ_F32(dst_ptr + 7 * dst_step + 1 * pack_tile, m1);
|
||||
MS_STQ_F32(dst_ptr + 7 * dst_step + 2 * pack_tile, m2);
|
||||
}
|
||||
#endif
|
||||
|
||||
void InputTransform8x8Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
|
||||
int block_tile = 12;
|
||||
int pack_tile = src_step;
|
||||
int src_point_stride = block_tile * pack_tile;
|
||||
#ifdef ENABLE_ARM64
|
||||
for (int l = 0; l < 8; ++l) {
|
||||
float *src_ptr = src_data + l * C4NUM * block_tile;
|
||||
TRANSPOSE_12x4;
|
||||
}
|
||||
|
||||
for (int c = 0; c < real_c; ++c) {
|
||||
float *src_ptr = src_data + c * block_tile;
|
||||
float *dst_ptr = dst_data + c * block_tile;
|
||||
InputTransform8x8Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride);
|
||||
}
|
||||
#else
|
||||
for (int l = 0; l < 8; ++l) {
|
||||
float *src = src_data + l * pack_tile * block_tile;
|
||||
// 12 * 4 -> 4 * 12
|
||||
float tmp_mat[pack_tile][block_tile];
|
||||
for (int i = 0; i < block_tile; ++i) {
|
||||
for (int j = 0; j < pack_tile; ++j) {
|
||||
tmp_mat[j][i] = src[i * pack_tile + j];
|
||||
}
|
||||
}
|
||||
memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float));
|
||||
}
|
||||
|
||||
float src[8];
|
||||
float m[8];
|
||||
for (int c = 0; c < real_c; ++c) {
|
||||
for (int i = 0; i < block_tile; ++i) {
|
||||
int tmp_index = c * block_tile + i;
|
||||
for (int w = 0; w < 8; ++w) {
|
||||
src[w] = src_data[tmp_index + w * src_point_stride];
|
||||
}
|
||||
m[0] = 0.5625f * src[0] - 3.0625f * src[2] + 3.5f * src[4] - src[6];
|
||||
float tmp1 = 1.125f * src[1] + 0.5f * src[5];
|
||||
float tmp2 = 2.25f * src[2] - 3.25f * src[4];
|
||||
m[1] = tmp1 + tmp2 - 1.625f * src[3] + src[6];
|
||||
m[2] = tmp2 - tmp1 + 1.625f * src[3] + src[6];
|
||||
tmp1 = 0.5625f * src[1] + src[5];
|
||||
tmp2 = 0.5625f * src[2] - 2.5f * src[4];
|
||||
m[3] = tmp1 + tmp2 - 2.5f * src[3] + src[6];
|
||||
m[4] = tmp2 - tmp1 + 2.5f * src[3] + src[6];
|
||||
tmp1 = 0.375f * src[1] + 1.5f * src[5];
|
||||
tmp2 = 0.25f * src[2] - 1.25f * src[4];
|
||||
m[5] = tmp1 + tmp2 - 1.875f * src[3] + src[6];
|
||||
m[6] = tmp2 - tmp1 + 1.875f * src[3] + src[6];
|
||||
m[7] = -0.5625f * src[1] + 3.0625f * src[3] - 3.5f * src[5] + src[7];
|
||||
|
||||
for (int w = 0; w < 8; ++w) {
|
||||
dst_data[tmp_index + w * dst_step] = m[w];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type) {
|
||||
if (!CheckWinogradInputOutputUnit(input_unit, output_unit)) {
|
||||
return NULL;
|
||||
|
|
|
@ -28,9 +28,21 @@ extern "C" {
#endif
typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);

typedef void (*InputTransStepFunc)(const float *src_data, float *dst_data, int src_step, int dst_step,
                                   int dst_row_step);

typedef void (*InputTransPackFunc)(float *src_data, float *dst_data, int src_step, int dst_step, int real_c);

typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step,
                                int dst_step, int out_c, int r_w, int r_h, int r_c);

typedef struct TransFuncList {
  InputTransFunc in_func_;
  InputTransStepFunc in_step_func_;
  InputTransPackFunc in_pack_func_;
  OutputTransFunc out_func_;
} TransFuncList;

#define Load16Data \
  src[0] = MS_LDQ_F32(src_data + 0 * src_step); \
  src[1] = MS_LDQ_F32(src_data + 1 * src_step); \

@ -153,14 +165,66 @@ typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const fl
  src[62] = MS_LDQ_F32(src_data + 62 * src_step); \
  src[63] = MS_LDQ_F32(src_data + 63 * src_step);

#define LOAD_LINE_DATA(line) \
  MS_FLOAT32X4 s##line##0 = MS_LDQ_F32(src_ptr + line * src_point_stride + 0 * pack_tile); \
  MS_FLOAT32X4 s##line##1 = MS_LDQ_F32(src_ptr + line * src_point_stride + 1 * pack_tile); \
  MS_FLOAT32X4 s##line##2 = MS_LDQ_F32(src_ptr + line * src_point_stride + 2 * pack_tile);

#define TRANSPOSE_12x4 \
  MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * pack_tile); \
  MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 1 * pack_tile); \
  MS_FLOAT32X4 s6 = MS_LDQ_F32(src_ptr + 2 * pack_tile); \
  MS_FLOAT32X4 s9 = MS_LDQ_F32(src_ptr + 3 * pack_tile); \
  MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 4 * pack_tile); \
  MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 5 * pack_tile); \
  MS_FLOAT32X4 s7 = MS_LDQ_F32(src_ptr + 6 * pack_tile); \
  MS_FLOAT32X4 s10 = MS_LDQ_F32(src_ptr + 7 * pack_tile); \
  MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 8 * pack_tile); \
  MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 9 * pack_tile); \
  MS_FLOAT32X4 s8 = MS_LDQ_F32(src_ptr + 10 * pack_tile); \
  MS_FLOAT32X4 s11 = MS_LDQ_F32(src_ptr + 11 * pack_tile); \
  transpose4(&s0, &s3, &s6, &s9); \
  transpose4(&s1, &s4, &s7, &s10); \
  transpose4(&s2, &s5, &s8, &s11); \
  MS_STQ_F32(src_ptr + 0 * pack_tile, s0); \
  MS_STQ_F32(src_ptr + 1 * pack_tile, s1); \
  MS_STQ_F32(src_ptr + 2 * pack_tile, s2); \
  MS_STQ_F32(src_ptr + 3 * pack_tile, s3); \
  MS_STQ_F32(src_ptr + 4 * pack_tile, s4); \
  MS_STQ_F32(src_ptr + 5 * pack_tile, s5); \
  MS_STQ_F32(src_ptr + 6 * pack_tile, s6); \
  MS_STQ_F32(src_ptr + 7 * pack_tile, s7); \
  MS_STQ_F32(src_ptr + 8 * pack_tile, s8); \
  MS_STQ_F32(src_ptr + 9 * pack_tile, s9); \
  MS_STQ_F32(src_ptr + 10 * pack_tile, s10); \
  MS_STQ_F32(src_ptr + 11 * pack_tile, s11);

InputTransFunc GetInputTransFunc(int input_unit);

#ifdef ENABLE_ARM64
InputTransStepFunc GetInputTransStepFunc(int input_unit);

InputTransPackFunc GetInputTransPackFunc(int input_unit);
#endif

void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);

void InputTransform4x4Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step);

void InputTransform4x4Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c);

void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);

void InputTransform6x6Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step);

void InputTransform6x6Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c);

void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);

void InputTransform8x8Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step);

void InputTransform8x8Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c);

OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type);

#define Store4Data \
@ -172,10 +172,10 @@ int ConvolutionWinogradFP32Coder::InitWeightBias() {
|
|||
}
|
||||
|
||||
int ConvolutionWinogradFP32Coder::ConfigInputOutput() {
|
||||
in_func_ = GetInputTransFunc(input_unit_);
|
||||
MS_CHECK_TRUE(!in_func_.empty(), "Get input_trans_func failed.");
|
||||
out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_);
|
||||
MS_CHECK_TRUE(!out_func_.empty(), "Get output_trans_func_ failed.");
|
||||
trans_func_str_.in_func_ = GetInputTransFunc(input_unit_);
|
||||
MS_CHECK_TRUE(!trans_func_str_.in_func_.empty(), "Get input_trans_func failed.");
|
||||
trans_func_str_.out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_);
|
||||
MS_CHECK_TRUE(!trans_func_str_.out_func_.empty(), "Get output_trans_func_ failed.");
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -269,9 +269,10 @@ int ConvolutionWinogradFP32Coder::DoCode(CoderContext *const context) {
|
|||
<< allocator_->GetRuntimeAddr(gemm_out_) << ", " << allocator_->GetRuntimeAddr(tmp_data_) << ", "
|
||||
<< allocator_->GetRuntimeAddr(col_buffer_) << "};\n";
|
||||
code.CodeStruct("conv_parameter", *conv_param_);
|
||||
code.CodeStruct("trans_func", trans_func_str_);
|
||||
// code operator func
|
||||
code.CodeFunction("ConvWinogardFp32", input_tensor_, trans_weight_, new_bias_, output_tensor_,
|
||||
"tmp_buffer_address_list", kDefaultTaskId, "&conv_parameter", in_func_, out_func_);
|
||||
"tmp_buffer_address_list", kDefaultTaskId, "&conv_parameter", "trans_func");
|
||||
context->AppendCode(code.str());
|
||||
return RET_OK;
|
||||
}
|
||||
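With the CodeStruct/CodeFunction calls above, the micro coder is expected to emit roughly the following C for an input_unit of 4; this is an illustration only, the actual variable names are produced by the allocator and the function pointers depend on input_unit, output_unit and act_type:

TransFuncList trans_func = {InputTransform4x4Unit, NULL, NULL, OutputTransform4x4Unit};
ConvWinogardFp32(g_input, g_trans_weight, g_new_bias, g_output, tmp_buffer_address_list, 0 /* kDefaultTaskId */,
                 &conv_parameter, trans_func);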
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <vector>
|
||||
#include "coder/opcoders/base/conv2d_base_coder.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "wrapper/fp32/conv_winograd_fp32_wrapper.h"
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
class ConvolutionWinogradFP32Coder : public Conv2DBaseCoder {
|
||||
|
@ -68,8 +69,7 @@ class ConvolutionWinogradFP32Coder : public Conv2DBaseCoder {
|
|||
float *gemm_out_{nullptr};
|
||||
float *col_buffer_{nullptr};
|
||||
|
||||
std::string in_func_;
|
||||
std::string out_func_;
|
||||
TransFuncStr trans_func_str_;
|
||||
};
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_FP32_CONVOLUTION_WINOGRAD_FP32_CODER_H_
|
||||
|
|
|
@ -157,4 +157,8 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const SpliceWrappe
|
|||
splice_param.src_to_dst_row_offset);
|
||||
}
|
||||
|
||||
void NNaclFp32Serializer::CodeStruct(const std::string &name, const TransFuncStr trans_func_str) {
|
||||
CodeBaseStruct("TransFuncList", name, trans_func_str.in_func_, nullptr, nullptr, trans_func_str.out_func_);
|
||||
}
|
||||
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
|
|
|
@ -36,6 +36,7 @@
|
|||
#include "nnacl/fp32/strided_slice_fp32.h"
|
||||
#include "wrapper/fp32/arithmetic_fp32_wrapper.h"
|
||||
#include "wrapper/base/affine_wrapper.h"
|
||||
#include "wrapper/fp32/conv_winograd_fp32_wrapper.h"
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
|
||||
|
@ -60,6 +61,7 @@ class NNaclFp32Serializer : public Serializer {
|
|||
void CodeStruct(const std::string &name, const StridedSliceParameter &strided_slice_parameter);
|
||||
void CodeStruct(const std::string &name, const ArithmeticWrapperInfo &arithmetic_wrapper_info);
|
||||
void CodeStruct(const std::string &name, const SpliceWrapperParam &splice_param);
|
||||
void CodeStruct(const std::string &name, const TransFuncStr trans_func_str);
|
||||
};
|
||||
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_CONV_WINOGRAD_FP32_WRAPPER_H_
|
||||
#define MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_CONV_WINOGRAD_FP32_WRAPPER_H_
|
||||
#include <string>
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
typedef struct TransFuncStr {
|
||||
std::string in_func_;
|
||||
std::string out_func_;
|
||||
} TransFuncStr;
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_CONV_WINOGRAD_FP32_WRAPPER_H_
|
|
@ -119,21 +119,40 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
opt_input_trans_ = reinterpret_cast<float16_t *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * row_tile_ * input_unit_ * input_unit_ *
|
||||
UP_ROUND(conv_param_->input_channel_, C8NUM) * sizeof(float16_t)));
|
||||
if (opt_input_trans_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc opt_input_trans_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
tmp_buffer_address_list_[0] = trans_input_;
|
||||
tmp_buffer_address_list_[1] = gemm_out_;
|
||||
tmp_buffer_address_list_[2] = tmp_data_;
|
||||
tmp_buffer_address_list_[3] = col_buffer_;
|
||||
tmp_buffer_address_list_[4] = opt_input_trans_;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradFP16CPUKernel::ConfigInputOutput() {
|
||||
in_func_ = GetInputTransFp16Func(input_unit_);
|
||||
if (in_func_ == nullptr) {
|
||||
trans_func_.in_func_ = GetInputTransFp16Func(input_unit_);
|
||||
if (trans_func_.in_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "in_func_ is null.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
out_func_ = GetOutputTransFp16Func(input_unit_, output_unit_, conv_param_->act_type_);
|
||||
if (out_func_ == nullptr) {
|
||||
#ifdef ENABLE_ARM64
|
||||
trans_func_.in_step_func_ = GetInputTransStepFp16Func(input_unit_);
|
||||
if (trans_func_.in_step_func_ == nullptr) {
|
||||
MS_LOG(DEBUG) << "in_step_func_ is null.";
|
||||
}
|
||||
trans_func_.in_pack_func_ = GetInputTransPackFp16Func(input_unit_);
|
||||
if (trans_func_.in_pack_func_ == nullptr) {
|
||||
MS_LOG(DEBUG) << "in_pack_func_ is null.";
|
||||
}
|
||||
#endif
|
||||
trans_func_.out_func_ = GetOutputTransFp16Func(input_unit_, output_unit_, conv_param_->act_type_);
|
||||
if (trans_func_.out_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "out_func_ is null.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -219,7 +238,7 @@ int ConvolutionWinogradFP16CPUKernel::RunImpl(int task_id) {
|
|||
}
|
||||
ConvWinogardFp16(input_ptr, reinterpret_cast<float16_t *>(packed_weight_),
|
||||
reinterpret_cast<const float16_t *>(bias_data_), output_ptr, tmp_buffer_address_list_, task_id,
|
||||
conv_param_, in_func_, out_func_);
|
||||
conv_param_, trans_func_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -65,6 +65,10 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseCPUKernel {
|
|||
ctx_->allocator->Free(col_buffer_);
|
||||
col_buffer_ = nullptr;
|
||||
}
|
||||
if (opt_input_trans_ != nullptr) {
|
||||
ctx_->allocator->Free(opt_input_trans_);
|
||||
opt_input_trans_ = nullptr;
|
||||
}
|
||||
}
|
||||
int FilterWeight();
|
||||
int kernel_unit_ = 0;
|
||||
|
@ -74,11 +78,11 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseCPUKernel {
|
|||
float16_t *trans_input_ = nullptr;
|
||||
float16_t *gemm_out_ = nullptr;
|
||||
float16_t *col_buffer_ = nullptr;
|
||||
float16_t *opt_input_trans_ = nullptr;
|
||||
float matrix_g_[64];
|
||||
float matrix_gt_[64];
|
||||
TmpBufferAddressFp16 tmp_buffer_address_list_[4] = {0};
|
||||
InputTransFp16Func in_func_ = nullptr;
|
||||
OutputTransFp16Func out_func_ = nullptr;
|
||||
TmpBufferAddressFp16 tmp_buffer_address_list_[5] = {0};
|
||||
TransFp16FuncList trans_func_;
|
||||
int col_tile_ = 0;
|
||||
int row_tile_ = 0;
|
||||
};
|
||||
|
|
|
@ -69,21 +69,40 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
opt_input_trans_ = reinterpret_cast<float *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * tile_num_ * input_unit_ * input_unit_ *
|
||||
UP_ROUND(conv_param_->input_channel_, tmp_data_tile_) * sizeof(float)));
|
||||
if (opt_input_trans_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc opt_input_trans_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
tmp_buffer_address_list_[0] = trans_input_;
|
||||
tmp_buffer_address_list_[1] = gemm_out_;
|
||||
tmp_buffer_address_list_[2] = tmp_data_;
|
||||
tmp_buffer_address_list_[3] = col_buffer_;
|
||||
tmp_buffer_address_list_[4] = opt_input_trans_;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradCPUKernel::ConfigInputOutput() {
|
||||
in_func_ = GetInputTransFunc(input_unit_);
|
||||
if (in_func_ == nullptr) {
|
||||
trans_func_.in_func_ = GetInputTransFunc(input_unit_);
|
||||
if (trans_func_.in_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "in_func_ is null.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_);
|
||||
if (out_func_ == nullptr) {
|
||||
#ifdef ENABLE_ARM64
|
||||
trans_func_.in_step_func_ = GetInputTransStepFunc(input_unit_);
|
||||
if (trans_func_.in_step_func_ == nullptr) {
|
||||
MS_LOG(DEBUG) << "in_step_func_ is null.";
|
||||
}
|
||||
trans_func_.in_pack_func_ = GetInputTransPackFunc(input_unit_);
|
||||
if (trans_func_.in_pack_func_ == nullptr) {
|
||||
MS_LOG(DEBUG) << "in_pack_func_ is null.";
|
||||
}
|
||||
#endif
|
||||
trans_func_.out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_);
|
||||
if (trans_func_.out_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "out_func_ is null.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -152,7 +171,7 @@ int ConvolutionWinogradCPUKernel::RunImpl(int task_id) {
|
|||
CHECK_NULL_RETURN(output_data);
|
||||
ConvWinogardFp32(ori_input_data, reinterpret_cast<float *>(packed_weight_),
|
||||
reinterpret_cast<const float *>(bias_data_), output_data, tmp_buffer_address_list_, task_id,
|
||||
conv_param_, in_func_, out_func_);
|
||||
conv_param_, trans_func_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -62,6 +62,10 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
|
|||
ctx_->allocator->Free(col_buffer_);
|
||||
col_buffer_ = nullptr;
|
||||
}
|
||||
if (opt_input_trans_ != nullptr) {
|
||||
ctx_->allocator->Free(opt_input_trans_);
|
||||
opt_input_trans_ = nullptr;
|
||||
}
|
||||
}
|
||||
int kernel_unit_{0};
|
||||
int input_unit_{0};
|
||||
|
@ -73,11 +77,11 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
|
|||
float *trans_input_ = nullptr;
|
||||
float *gemm_out_ = nullptr;
|
||||
float *col_buffer_ = nullptr;
|
||||
float *opt_input_trans_ = nullptr;
|
||||
float matrix_g_[64];
|
||||
float matrix_gt_[64];
|
||||
TmpBufferAddress tmp_buffer_address_list_[4] = {nullptr};
|
||||
InputTransFunc in_func_ = nullptr;
|
||||
OutputTransFunc out_func_ = nullptr;
|
||||
TmpBufferAddress tmp_buffer_address_list_[5] = {nullptr};
|
||||
TransFuncList trans_func_;
|
||||
};
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -3,13 +3,13 @@
|
|||
# [second column]:accuracy limit for float16 in arm64 device
|
||||
hdc_age_medium 5.9
|
||||
beard 2
|
||||
emotion 60
|
||||
gender_res_large_deploy 0.1
|
||||
emotion 216
|
||||
gender_res_large_deploy 2
|
||||
glasses 4
|
||||
hat 2.5
|
||||
ml_bank_detect_0312_tmp 20
|
||||
ml_face_div_parsing 8
|
||||
ml_hardware_eyeclose 0.1
|
||||
ml_hardware_eyeclose 0.5
|
||||
ml_ocr_detect_20200305 10
|
||||
Mnet6_0312_extract_pay 15
|
||||
pose_3d 90
|
||||
|
@ -37,7 +37,7 @@ ml_ocr_sfz_add_final_0325 0.1
|
|||
ml_hardware_pose 2
|
||||
ml_bank_recog 0.1
|
||||
2012_ATLANTA_10class_20190131_v4.0 12
|
||||
mnet 12
|
||||
mnet 13
|
||||
recognition 10.8
|
||||
ml_face_landmark 1
|
||||
model_hebing_3branch 40
|
||||
|
@ -71,31 +71,31 @@ ml_location_scene_division 8
|
|||
ml_tabel_recog 0.1
|
||||
ml_text_division 12
|
||||
# Model ml_video_edit_Mnet needs further analysis in the future
|
||||
ml_video_edit_Mnet 11.5
|
||||
ml_video_edit_Mnet 15.5
|
||||
ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145 0.5
|
||||
hdc_contour_pose_128 0.5
|
||||
hdc_emotion 0.5
|
||||
hdc_fivembnet 1
|
||||
hdc_isface 0.5
|
||||
hdc_mobilenetface 11.5 # small output causes big bias
|
||||
hdc_retinaface 14
|
||||
hdc_retinaface 15
|
||||
hdc_resnet 7
|
||||
ml_video_edit_detect_20211111 2.5
|
||||
ml_video_edit_detect_20211111 3
|
||||
ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145_20210121 0.5
|
||||
ml_video_edit_have_imageProcessLayer_interpTo145_20201015 0.5
|
||||
ml_video_edit_MnetN367_extract_1010_pay 1
|
||||
ml_video_edit_person_divison_pic 0.5
|
||||
ml_video_edit_reid 1
|
||||
ml_video_edit_v10_best_model_nomean_20200723 5.1
|
||||
ml_video_edit_v10_best_model_nomean_20200723 6
|
||||
ml_video_edit_img_segment 3
|
||||
ml_video_edit_video_segment_gauss_adaptis_part1 5
|
||||
# When the input range is [-1,1], the precision is poor, and the output value is very small (10e-5). If the input range is adjusted to [0,255], the precision will decrease to 15.5415%, and the rest is cumulative error.
|
||||
ml_handpose 175
|
||||
ml_handpose 177
|
||||
hdc_Face_Aesthetic_MTI_Aesthetic 0.5
|
||||
ml_face_compare 8.7
|
||||
ml_face_tracking 2.5
|
||||
ml_face_beard 0.6
|
||||
ml_face_age 3.7
|
||||
ml_face_beard 1
|
||||
ml_face_age 4
|
||||
ml_face_pose 1
|
||||
ml_face_isface 0.5
|
||||
ml_face_glasses 3.4
|
||||
|
@ -108,13 +108,13 @@ ml_Hand_deploy 4
|
|||
ml_hand_3d_detection 12
|
||||
ml_hand_3d_regression 5.4
|
||||
# ml_ARengine23_bodypose: The output-node difference is divided by a very small value, which leads to a large error
|
||||
ml_ARengine23_bodypose 56
|
||||
ml_ARengine23_bodypose 57
|
||||
ml_ocr_bank_card_detection_inception_tmp 20
|
||||
ml_ocr_bank_card_recognition_fcny 0.5
|
||||
hiai_cv_aestheticsEngineModel_osp 1.6
|
||||
hiai_cv_aestheticsEngineModel_osp 3.5
|
||||
ml_face_hat 2.2
|
||||
bank_card_recognition_fcny 17
|
||||
bank_card_detection_inception_tmp 12
|
||||
bank_card_recognition_fcny 19
|
||||
bank_card_detection_inception_tmp 13.5
|
||||
ml_ocr_identify_card_fcny 0.5
|
||||
ml_ocr_identify_card_detect_tmp 2
|
||||
identify_card_detect_tmp 0.5
|
||||
|
@ -123,18 +123,18 @@ ml_2012_ocr_rec_caffe 0.5
|
|||
ml_lable_model_hebing_device 3
|
||||
ml_face_sex 0.7
|
||||
# ml_face_mnet: The precision problem caused by cumulative error.
|
||||
ml_face_mnet 12
|
||||
ml_face_mnet 13
|
||||
ml_segmentation_atlanta_1 0.5
|
||||
bolt_deploy_color-server 0.5
|
||||
ml_face_emotion 0.5
|
||||
hdc_ocr_recog_horizontal 0.5
|
||||
# The outputs of two Heatmap_depth models have small value
|
||||
ml_Heatmap_depth_240180;2 10
|
||||
ml_Heatmap_depth_240180;2 14.5
|
||||
ml_Heatmap_depth_180240;2 7
|
||||
ml_video_edit_hair_dyeing_segmodel_v3 0.5
|
||||
ml_video_edit_hairline_segmentation;3 1.5
|
||||
ml_video_edit_seg_320 0.5
|
||||
hiai_machine_vision_jfr_newmodel_2730_houduan_yolo 5
|
||||
hiai_machine_vision_mobileNet101_nosoftce_mobilenet_resnet 7.5
|
||||
ml_video_edit_person_divison_video;2 38
|
||||
ml_video_edit_person_divison_video;2 42
|
||||
ml_video_edit_hair_dyeing_segmodel_20211119 0.5
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
# [second column]:accuracy limit for float16 in arm64 device
|
||||
mtk_detect-mbv2-shortcut-400-400-simplified.onnx 4
|
||||
mtk_face_features_v3.onnx 20
|
||||
emotion-ferplus-8.onnx 1
|
||||
emotion-ferplus-8.onnx 1.5
|
||||
#rcnn-ilsvrc13-9.onnx 0.1
|
||||
efficientnet-lite4-11.onnx 2
|
||||
mobilenetv2-7.onnx 8
|
||||
|
@ -27,7 +27,7 @@ mnist-8.onnx 10
|
|||
crnn_lite_lstm_v2.onnx;1;32,32,32,1 0.3
|
||||
#psenet_lite_mbv2.onnx;1;1,32,32,3 0.6
|
||||
#occasionally aborted
|
||||
super-resolution-10.onnx;1;1,224,224,1 4.5
|
||||
super-resolution-10.onnx;1;1,224,224,1 5
|
||||
tinyyolov2-8.onnx;1;1,416,416,3 5.5
|
||||
#ml_2012_ocr_cn.onnx -1
|
||||
#ml_2012_ocr_cn_noLSTM.onnx 1
|
||||
|
@ -52,10 +52,10 @@ ml_video_edit_style_transfer_autoportrait.onnx 2
|
|||
ml_video_edit_style_transfer_candy.onnx 2
|
||||
ml_video_edit_style_transfer_gongnongbing.onnx 2
|
||||
ml_video_edit_style_transfer_starry.onnx 2
|
||||
hdc_Face_Landmark5_MTI_Aesthetic.onnx 0.5
|
||||
hdc_Face_Landmark5_MTI_Aesthetic.onnx 1
|
||||
hdc_Image_Aesthetic_MTI_Aesthetic.onnx 0.5
|
||||
hdc_resnet_1w_class.onnx 6
|
||||
gts_text_detection.onnx;1;1,224,224,3 10
|
||||
gts_text_detection.onnx;1;1,224,224,3 11
|
||||
hdc_Face_Emotion_MTI_Aesthetic.onnx 144
|
||||
ml_video_edit_imitate_filter.onnx 120
|
||||
ml_facedetector.onnx 6
|
||||
|
@ -71,7 +71,7 @@ mtk_emotions-d2012-75.onnx 6
|
|||
mtk_detect-mbv1-shortcut-400-400.onnx 0.5
|
||||
mtk_detect-mbv2-shortcut-400-400.onnx 0.5
|
||||
mtk_detect_mbv1_640_480.onnx 0.5
|
||||
mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified_onnx.onnx 2
|
||||
mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified_onnx.onnx 2.5
|
||||
mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified_onnx.onnx 6.5
|
||||
mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified_onnx.onnx 2.5
|
||||
mtk_detect_mbv1_640_480_nopostprocess_simplified_onnx.onnx;1;1,480,640,3 2
|
||||
|
@ -87,7 +87,7 @@ Q888_iris_detect.onnx 0.5
|
|||
ssd_mobilenet_v1_10.onnx;1;1,383,640,3 0.5
|
||||
# The output from a conv in the later part contains many negative values; the following leakyRelu makes them become very
# close to 0 (-e^-4). fp16 loses a lot of precision in this case, and that affects the following computation.
|
||||
Harmony_Voiceprint.onnx;1;1,200,40,1 21.5 # small output causes big bias
|
||||
Harmony_Voiceprint.onnx;1;1,200,40,1 68 # small output causes big bias
|
||||
# A matmul op in the later part produces overflowed output values (>65504).
|
||||
#ml_video_edit_art_generate_20210513.onnx nan
|
||||
# bn_fusion causes a big bias(maybe random), need to debug later: The original bias is 2.1
|
||||
|
@ -104,11 +104,11 @@ ml_video_edit_makeup_mobilenetv203.onnx 4
|
|||
# The input of ml_video_edit_hair_dyeing_migrate_v2.onnx should be between [0, 1]
|
||||
ml_video_edit_hair_dyeing_migrate_v2.onnx;4 2.5
|
||||
Q888_CV_face_recognition_self.onnx 3.9
|
||||
ml_video_edit_hair_dyeing_migrate_v2_fix.onnx;4 3
|
||||
ml_video_edit_hair_dyeing_migrate_v2_fix.onnx;4 3.5
|
||||
ml_intelligent_cockpit_model.onnx;3;1,32:1,32:1,32 3.8
|
||||
CloudBU_FSRCNN_RTC_8ch_3450_QP9.onnx;1;1,225,225,3 1.5
|
||||
CloudBU_rfdn_rtc_x2_ver2_13.onnx;1;1,225,225,3 1.5
|
||||
CloudBU_rfdn_rtc_x2_ver2_3450.onnx;1;1,225,225,3 3.0
|
||||
CloudBU_rfdn_rtc_x2_ver2_3450.onnx;1;1,225,225,3 4
|
||||
ml_motion_capture_nanodet_m_0.5x_people_0928_sim.onnx 8
|
||||
ml_motion_capture_smpl_0916.onnx;3
|
||||
ml_motion_capture_spin_mobile_mv3_v3_57mm_sim.onnx;5 18
|
||||
|
@ -116,7 +116,7 @@ ml_video_edit_dimming_tech_model_345000_color.onnx;2 2
|
|||
Ireland_ulfgf.onnx;1;1,240,320,3
|
||||
Ireland_gaze_corrector.onnx;3 15
|
||||
Ireland_face_detector.onnx 2
|
||||
Ireland_gaze_estimator_ng.onnx 6
|
||||
Ireland_gaze_estimator_ng.onnx 8
|
||||
carbu_intelligent_cockpit_fasttext_best.onnx 0.5
|
||||
ml_video_edit_shot_selection_yolox_nano_coco_reduced.onnx 3
|
||||
ml_video_edit_shot_selection_face_emotion.onnx 0.7
|
||||
|
|
|
@ -18,7 +18,7 @@ hiai_ssd_mobilenetv2_object.pb 15
|
|||
hiai_humanDetection.pb 3.5
|
||||
hiai_PoseEstimation_Pcm.pb 0.5
|
||||
# The last layer has a very small value, which leads to a large error
|
||||
hiai_cn_recognize_modify_padv2.pb;1;1,32,512,1 27
|
||||
hiai_cn_recognize_modify_padv2.pb;1;1,32,512,1 37
|
||||
hiai_model_normalize_object_scene_ps_20200519.pb;1;1,224,224,3 17.1
|
||||
# The output of mtk_model_ckpt.pb has small value
|
||||
mtk_model_ckpt.pb 19.5
|
||||
|
@ -33,14 +33,14 @@ mtk_face_features_v1.pb 26
|
|||
model_normalize_object_scene_ps_20200519.pb;1;1,224,224,3 10
|
||||
hiai_AADB_HADB_MBV2_model.pb;1;1,224,224,3 6
|
||||
hiai_frozen_inference_graph.pb 12
|
||||
hiai_lm_inference_graph.pb 1.2
|
||||
hiai_lm_inference_graph.pb 1.5
|
||||
hiai_ghostnet.pb 0.9
|
||||
hiai_face_model_npu.pb 0.5
|
||||
hiai_cv_focusShootOCRModel_02.pb 10.5
|
||||
hiai_cv_focusShootOCRModel_02.pb 12.5
|
||||
hiai_label_and_video.pb;1;1,224,224,3 23
|
||||
hiai_dress_detect.pb;1;1,960,960,3 1.5
|
||||
hiai_iMaxDN_RGB.pb 0.5
|
||||
hiai_iMaxSR_RGB.pb 3.5
|
||||
hiai_iMaxSR_RGB.pb 5
|
||||
hiai_ctpn_feature_map.pb 6.5
|
||||
hiai_cpu_face_gazing.pb 0.5
|
||||
hiai_cpu_face_emotion.pb 2.2
|
||||
|
@ -49,7 +49,7 @@ Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid.pb 1.5
|
|||
# The input of the Q_crnn_ori_75w_slim model is between 0-255, but its outputs have small values (e-6).
|
||||
Q_crnn_ori_75w_slim_norm.pb 37
|
||||
# The output of Q_crnn_ori_v2 model has small values (e-4).
|
||||
Q_crnn_ori_v2_405001_notrans_nopre.pb 24
|
||||
Q_crnn_ori_v2_405001_notrans_nopre.pb 33
|
||||
# The input of hiai_latin models are between 0-255
|
||||
hiai_latin_ocr.pb 4
|
||||
hiai_latin_ocr_1.pb 3.5
|
||||
|
@ -68,7 +68,7 @@ ml_vision_guide_detection2.pb;1;1,320,320,1 1
|
|||
ml_tts_encoder.pb;4;1,44:1:1:1 9
|
||||
# encoder_0111_control_flow.pb is the same as ml_tts_encoder_control_flow.pb
|
||||
#encoder_0111_control_flow.pb;4;1:1,44:1:1 10
|
||||
ml_video_edit_video_segment_gauss_adaptis_part2.pb;2 12.1
|
||||
ml_video_edit_video_segment_gauss_adaptis_part2.pb;2 16
|
||||
ml_video_edit_img_segment_adaptise.pb;2 40
|
||||
ml_video_edit_oneclick_adaptis.pb;3 6
|
||||
#decoder_step_201217.pb is the same model as ml_tts_decoder.pb.
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# content after ";" can be omitted.
|
||||
# [second column]:accuracy limit for float16 in arm64 device
|
||||
hiai_model_0909_kd_rot_ps_softmax.tflite 10
|
||||
hiai_chinese_english_recognize_model_float32.tflite 13
|
||||
hiai_chinese_english_recognize_model_float32.tflite 13.5
|
||||
hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite 10
|
||||
hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite 10
|
||||
hiai_cn_recognize_modify_padv2.tflite 14
|
||||
|
@ -64,7 +64,7 @@ inception_resnet_v2.tflite 10
|
|||
ml_ocr_latin.tflite 15
|
||||
hiai_PoseEstimation_Pcm.tflite 15
|
||||
hiai_ssd_mobilenetv2_object.tflite 60
|
||||
hiai_cv_focusShootOCRModel_02.tflite 13
|
||||
hiai_cv_focusShootOCRModel_02.tflite 13.5
|
||||
hiai_cv_poseEstimation.tflite 190
|
||||
inception_v4.tflite 10
|
||||
mtk_model_normalize_object_scene_ps_20200519_f16.tflite 10
|
||||
|
@ -129,8 +129,8 @@ mtk_pose.tflite 2
|
|||
mtk_model_emotions_0727_nosoftmax.tflite 2
|
||||
mtk_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite 22
|
||||
mtk_276landmark_0913.tflite 16
|
||||
mtk_face_recognition.tflite 8
|
||||
mtk_convert_model.tflite 5.3
|
||||
mtk_face_recognition.tflite 11
|
||||
mtk_convert_model.tflite 7.5
|
||||
smartreply.tflite 0.1
|
||||
mindspore_text_classification_tflite.tflite 9.2 # small output causes big bias
|
||||
#ml_location.tflite 0.1
|
||||
|
@ -176,7 +176,7 @@ Q_convert.tflite 12
|
|||
# the input of the Q_crnn_ori_75w_slim model is between 0-255, but its outputs have small values (e-6).
|
||||
Q_crnn_ori_75w_slim_norm_pb2tflite.tflite 29
|
||||
# the output of Q_crnn_ori_v2 model has small values (e-4).
|
||||
Q_crnn_ori_v2_405001_notrans_nopre_pb2tflite.tflite 36
|
||||
Q_crnn_ori_v2_405001_notrans_nopre_pb2tflite.tflite 42
|
||||
# the inputs of two Q_crnn_screen_slim400w models are between 0-255, but their outputs have small values (e-7).
|
||||
Q_crnn_screen_slim400w_more_20w_pb2tflite.tflite 71
|
||||
Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid_tflite.tflite 1.5
|
||||
|
@ -202,14 +202,14 @@ Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite 2
|
|||
# input data: -1~1
|
||||
Q888_face_emo_dress_mv3_orderd.tflite 2.5
|
||||
Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite 1
|
||||
Q_iMaxSR_RGB_385_p_pb2tflite.tflite 5
|
||||
Q_iMaxSR_RGB_385_p_pb2tflite.tflite 5.5
|
||||
bloom_new_detect.tflite 3.5
|
||||
bloom_model_age_gender.tflite 0.5
|
||||
bloom_isface.tflite 0.5
|
||||
# The output values of conv layers range from -e±5 to e±5, which almost reaches the representation limit of fp16. In
# this range, the fp16 data will have a big bias, and the accumulation of this bias lowers the final precision.
|
||||
hiai_object_detect_814.tflite 14
|
||||
ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2 12.1
|
||||
ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2 16
|
||||
ml_video_edit_img_segment_adaptise_pb2tflite.tflite;2 0.5
|
||||
hdc_tb_cn_neg.tflite;3 295
|
||||
# The input of hiai_cv_labelDetectorModel_v3.tflite is between 0-255.
|
||||