[MSLITE][DEVELOP] optimize conv winograd

This commit is contained in:
yangruoqi713 2022-01-17 14:18:46 +08:00
parent ad5c5ce5f8
commit 4b0edb34ce
26 changed files with 1521 additions and 231 deletions

View File

@ -195,3 +195,4 @@ mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition
mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::PartitionNode
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c:InstanceNormNC8HW8Fp16
mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::init_global_variable
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_winograd_fp32.c:ConvWinogardFp32

View File

@ -187,7 +187,7 @@ void Conv1x1OutNc8hw8MultiThreadByWeightFp16(const float16_t *input, float16_t *
// fp16 convolution winograd
void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight, const float16_t *bias_data,
float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id,
const ConvParameter *conv_param, InputTransFp16Func in_func, OutputTransFp16Func out_func) {
const ConvParameter *conv_param, TransFp16FuncList trans_func) {
#ifdef ENABLE_ARM64
const int tile_num = 16;
#else
@ -196,6 +196,7 @@ void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight
NNACL_CHECK_ZERO_RETURN(conv_param->output_unit_);
NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_);
int in_channel = conv_param->input_channel_;
int input_unit = conv_param->input_unit_;
int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_);
int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_);
int output_count = out_w_block * out_h_block;
@ -204,16 +205,12 @@ void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight
NNACL_CHECK_ZERO_RETURN(real_tile);
int output_tile_count = UP_DIV(output_count, real_tile);
int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_;
int input_unit_square = input_unit * input_unit;
float16_t *trans_input = buffer_list[0];
float16_t *gemm_out = buffer_list[1];
float16_t *tmp_data = buffer_list[2];
float16_t *col_buffer = buffer_list[3];
int trans_input_offset = tile_num * input_unit_square * in_channel;
int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
int tmp_data_offset = input_unit_square * C8NUM;
int col_buffer_offset = tile_num * in_channel;
float16_t *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel;
float16_t *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM;
float16_t *tmp_data = buffer_list[2] + task_id * input_unit_square * C8NUM;
float16_t *col_buffer = buffer_list[3] + task_id * tile_num * in_channel;
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)
for (int b = 0; b < conv_param->input_batch_; b++) {
@ -226,30 +223,64 @@ void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight
if (cal_num <= 0) {
return;
}
WinogradInputTransformFp16(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
in_func);
// step 3 : gemm
float16_t *src_ptr = trans_input + task_id * trans_input_offset;
float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset;
float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
#ifdef ENABLE_ARM64
RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
// Optimize input transform. Only valid for arm64, the tile num is 16.
// For arm32, the tile_num is 12. The function(InputTransform4x4Pack12Fp16) needs to be rewritten.
bool fused_pack =
(cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL);
if (fused_pack) {
float16_t *opt_trans_input =
buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, C8NUM);
WinogradInputTransformOptStepFp16(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num,
out_tile_index, out_w_block, conv_param, trans_func.in_step_func_);
for (int w_index = 0; w_index < input_unit; w_index++) {
float16_t *src_w = opt_trans_input + w_index * input_unit * tile_num * C8NUM;
for (int c = 0; c < UP_DIV(in_channel, C8NUM); c++) {
int real_c = in_channel - c * C8NUM;
real_c = real_c > C8NUM ? C8NUM : real_c;
float16_t *src_c = src_w + c * input_unit_square * tile_num * C8NUM;
float16_t *dst_c = trans_input + c * tile_num * C8NUM;
trans_func.in_pack_func_(src_c, dst_c, C8NUM, in_channel * tile_num, real_c);
}
for (int h_index = 0; h_index < input_unit; h_index++) {
const float16_t *gemm_input = trans_input + h_index * tile_num * in_channel;
int point_index = h_index * input_unit + w_index;
const float16_t *gemm_weight = trans_weight + point_index * in_channel * oc8 * C8NUM;
MatMulFp16(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num,
oc8 * C8NUM, input_unit_square, OutType_TileC8);
}
}
} else {
#endif
WinogradInputTransformFp16(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index,
out_w_block, conv_param, trans_func.in_func_);
// step 3 : gemm
float16_t *src_ptr = trans_input;
float16_t *dst_ptr = gemm_out;
float16_t *tmp_col_ptr = col_buffer;
for (int i = 0; i < input_unit_square; ++i) {
#ifdef ENABLE_ARM64
RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
#else
RowMajor2Col12MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
#endif
MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8);
MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8);
}
#ifdef ENABLE_ARM64
}
#endif
// step 4 : output transform
if (conv_param->out_format_ != NNACL_NC4HW4) { // nc4hw4
WinogradOutputNHWCTransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset, bias_data,
cal_num, out_tile_index, out_w_block, conv_param, out_func);
WinogradOutputNHWCTransformFp16(gemm_out, output_data + out_batch_offset, bias_data, cal_num, out_tile_index,
out_w_block, conv_param, trans_func.out_func_);
} else {
WinogradOutputNC8HW8TransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset,
bias_data, cal_num, out_tile_index, out_w_block, conv_param, out_func);
WinogradOutputNC8HW8TransformFp16(gemm_out, output_data + out_batch_offset, bias_data, cal_num, out_tile_index,
out_w_block, conv_param, trans_func.out_func_);
}
}
}

View File

@ -49,7 +49,7 @@ void Conv1x1OutNc8hw8MultiThreadByWeightFp16(const float16_t *input, float16_t *
// fp16 convolution winograd
void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight, const float16_t *bias_data,
float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id,
const ConvParameter *conv_param, InputTransFp16Func in_func, OutputTransFp16Func out_func);
const ConvParameter *conv_param, TransFp16FuncList trans_func);
#ifdef __cplusplus
}

View File

@ -16,6 +16,77 @@
#include "nnacl/fp16/winograd_transform_fp16.h"
void PrepareTransInputFp16(const float16_t *src_data, float16_t *dst_data, int interval_x_s, int interval_x_e,
int interval_y_s, int interval_y_e, int real_c, const ConvParameter *conv_param) {
int input_unit = conv_param->input_unit_;
int in_channel = conv_param->input_channel_;
int input_w = conv_param->input_w_;
// clear tmp buffer
if (interval_x_e - interval_x_s != input_unit || interval_y_e - interval_y_s != input_unit) {
memset(dst_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t));
}
// get real input block with padding
if (real_c == C8NUM) {
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = (interval * input_w + interval_x_s) * in_channel;
int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM;
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * in_channel;
int dst_x_offset = dst_y_offset + j * C8NUM;
const float16_t *src_addr = src_data + src_x_offset;
float16_t *dst_addr = dst_data + dst_x_offset;
#ifdef ENABLE_NEON
vst1q_f16(dst_addr, vld1q_f16(src_addr));
#else
for (int k = 0; k < C8NUM; k++) {
dst_addr[k] = src_addr[k];
}
#endif
}
}
} else if (real_c < 8 && real_c >= 4) {
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = (interval * input_w + interval_x_s) * in_channel;
int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM;
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * in_channel;
int dst_x_offset = dst_y_offset + j * C8NUM;
const float16_t *src_addr = src_data + src_x_offset;
float16_t *dst_addr = dst_data + dst_x_offset;
int rc = real_c - 4;
#ifdef ENABLE_NEON
vst1_f16(dst_addr, vld1_f16(src_addr));
#else
for (int k = 0; k < C4NUM; k++) {
dst_addr[k] = src_addr[k];
}
#endif
src_addr += 4;
dst_addr += 4;
for (int i = 0; i < rc; ++i) {
dst_addr[i] = src_addr[i];
}
}
}
} else {
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = (interval * input_w + interval_x_s) * in_channel;
int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM;
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * in_channel;
int dst_x_offset = dst_y_offset + j * C8NUM;
const float16_t *src_addr = src_data + src_x_offset;
float16_t *dst_addr = dst_data + dst_x_offset;
for (int k = 0; k < real_c; k++) {
dst_addr[k] = src_addr[k];
}
}
}
}
}
// fp16 common winograd
void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
@ -49,71 +120,11 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s);
int dst_plane_offset = c * in_channel;
for (int ic = 0; ic < ic8; ic++) {
// clear tmp buffer
memset(tmp_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t));
int real_c = in_channel - ic * C8NUM;
real_c = real_c > C8NUM ? C8NUM : real_c;
int src_ic8_offset = src_plane_offset + ic * C8NUM;
// get real input block with padding
if (real_c == C8NUM) {
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel;
int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM;
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * in_channel;
int dst_x_offset = dst_y_offset + j * C8NUM;
const float16_t *src_addr = input_data + src_x_offset;
float16_t *dst_addr = tmp_data + dst_x_offset;
#ifdef ENABLE_NEON
vst1q_f16(dst_addr, vld1q_f16(src_addr));
#else
for (int k = 0; k < C8NUM; k++) {
dst_addr[k] = src_addr[k];
}
#endif
}
}
} else if (real_c < 8 && real_c >= 4) {
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel;
int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM;
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * in_channel;
int dst_x_offset = dst_y_offset + j * C8NUM;
const float16_t *src_addr = input_data + src_x_offset;
float16_t *dst_addr = tmp_data + dst_x_offset;
int rc = real_c - 4;
#ifdef ENABLE_NEON
vst1_f16(dst_addr, vld1_f16(src_addr));
#else
for (int k = 0; k < C4NUM; k++) {
dst_addr[k] = src_addr[k];
}
#endif
src_addr += 4;
dst_addr += 4;
for (int i = 0; i < rc; ++i) {
dst_addr[i] = src_addr[i];
}
}
}
} else {
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel;
int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM;
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * in_channel;
int dst_x_offset = dst_y_offset + j * C8NUM;
const float16_t *src_addr = input_data + src_x_offset;
float16_t *dst_addr = tmp_data + dst_x_offset;
for (int k = 0; k < real_c; k++) {
dst_addr[k] = src_addr[k];
}
}
}
}
const float16_t *src_data = input_data + src_plane_offset + ic * C8NUM;
PrepareTransInputFp16(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c,
conv_param);
// input transform
int dst_ic8_offset = dst_plane_offset + ic * C8NUM;
@ -125,6 +136,51 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
} // cal_tile_num loop
}
// Only support arm64
void WinogradInputTransformOptStepFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
int cal_num, int out_tile_index, int out_w_block_num,
const ConvParameter *conv_param, InputTransStepFp16Func func) {
const int tile_num = 16;
int input_unit = conv_param->input_unit_;
int output_unit = conv_param->output_unit_;
int in_channel = conv_param->input_channel_;
int ic8 = UP_DIV(in_channel, C8NUM);
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int input_h = conv_param->input_h_;
int input_w = conv_param->input_w_;
if (out_w_block_num == 0) {
return;
}
for (int c = 0; c < cal_num; c++) { // actual tiled number
int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w;
int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h;
int interval_x_s = src_x_s > 0 ? 0 : -src_x_s;
int interval_y_s = src_y_s > 0 ? 0 : -src_y_s;
int src_x_e = src_x_s + input_unit;
int src_y_e = src_y_s + input_unit;
int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s);
int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s);
int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s);
int dst_plane_offset = c * C8NUM;
for (int ic = 0; ic < ic8; ic++) {
int real_c = in_channel - ic * C8NUM;
real_c = real_c > C8NUM ? C8NUM : real_c;
const float16_t *src_data = input_data + src_plane_offset + ic * C8NUM;
PrepareTransInputFp16(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c,
conv_param);
// input transform
int dst_ic8_offset = dst_plane_offset + ic * tile_num * input_unit * input_unit * C8NUM;
size_t dst_step = input_unit * tile_num * C8NUM;
float16_t *trans_input_ptr = trans_input + dst_ic8_offset;
func(tmp_data, trans_input_ptr, C8NUM, dst_step, tile_num * C8NUM);
}
out_tile_index++;
} // cal_tile_num loop
}
void WinogradOutputNHWCTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
int cal_num, int out_tile_index, int output_unit_num,
const ConvParameter *conv_param, OutputTransFp16Func func) {

View File

@ -33,6 +33,10 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
InputTransFp16Func func);
void WinogradInputTransformOptStepFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
int cal_num, int out_tile_index, int out_w_block_num,
const ConvParameter *conv_param, InputTransStepFp16Func func);
void WinogradOutputNHWCTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
int cal_num, int out_tile_index, int output_unit_num,
const ConvParameter *conv_param, OutputTransFp16Func func);

View File

@ -20,6 +20,38 @@
#define MIN_UNIT_FP16 2
#define MAX_UNIT_FP16 4
#ifdef ENABLE_ARM64
void transpose8(float16x8_t *s0, float16x8_t *s1, float16x8_t *s2, float16x8_t *s3, float16x8_t *s4, float16x8_t *s5,
float16x8_t *s6, float16x8_t *s7) {
float32x4_t m0 = (float32x4_t)(vtrn1q_f16(*s0, *s1));
float32x4_t m1 = (float32x4_t)(vtrn2q_f16(*s0, *s1));
float32x4_t m2 = (float32x4_t)(vtrn1q_f16(*s2, *s3));
float32x4_t m3 = (float32x4_t)(vtrn2q_f16(*s2, *s3));
float32x4_t m4 = (float32x4_t)(vtrn1q_f16(*s4, *s5));
float32x4_t m5 = (float32x4_t)(vtrn2q_f16(*s4, *s5));
float32x4_t m6 = (float32x4_t)(vtrn1q_f16(*s6, *s7));
float32x4_t m7 = (float32x4_t)(vtrn2q_f16(*s6, *s7));
float64x2_t t0 = (float64x2_t)(vtrn1q_f32(m0, m2));
float64x2_t t2 = (float64x2_t)(vtrn2q_f32(m0, m2));
float64x2_t t1 = (float64x2_t)(vtrn1q_f32(m1, m3));
float64x2_t t3 = (float64x2_t)(vtrn2q_f32(m1, m3));
float64x2_t t4 = (float64x2_t)(vtrn1q_f32(m4, m6));
float64x2_t t6 = (float64x2_t)(vtrn2q_f32(m4, m6));
float64x2_t t5 = (float64x2_t)(vtrn1q_f32(m5, m7));
float64x2_t t7 = (float64x2_t)(vtrn2q_f32(m5, m7));
*s0 = (float16x8_t)(vtrn1q_f64(t0, t4));
*s4 = (float16x8_t)(vtrn2q_f64(t0, t4));
*s1 = (float16x8_t)(vtrn1q_f64(t1, t5));
*s5 = (float16x8_t)(vtrn2q_f64(t1, t5));
*s2 = (float16x8_t)(vtrn1q_f64(t2, t6));
*s6 = (float16x8_t)(vtrn2q_f64(t2, t6));
*s3 = (float16x8_t)(vtrn1q_f64(t3, t7));
*s7 = (float16x8_t)(vtrn2q_f64(t3, t7));
}
#endif
static InputTransFp16Func InputTransFp16FuncList[] = {
NULL, NULL, NULL, NULL, InputTransform4x4UnitFp16, NULL, InputTransform6x6UnitFp16, NULL, InputTransform8x8UnitFp16};
@ -81,6 +113,25 @@ static OutputTransFp16Func OutputTransFp16FuncRelu6List8[] = {NULL,
InputTransFp16Func GetInputTransFp16Func(int input_unit) { return InputTransFp16FuncList[input_unit]; }
#ifdef ENABLE_ARM64
static InputTransStepFp16Func InputTransStepFp16FuncList[] = {
NULL, NULL, NULL, NULL, InputTransform4x4StepFp16, NULL, InputTransform6x6StepFp16, NULL, InputTransform8x8StepFp16};
static InputTransPackFp16Func InputTransPackFp16FuncList[] = {NULL,
NULL,
NULL,
NULL,
InputTransform4x4Pack16Fp16,
NULL,
InputTransform6x6Pack16Fp16,
NULL,
InputTransform8x8Pack16Fp16};
InputTransStepFp16Func GetInputTransStepFp16Func(int input_unit) { return InputTransStepFp16FuncList[input_unit]; }
InputTransPackFp16Func GetInputTransPackFp16Func(int input_unit) { return InputTransPackFp16FuncList[input_unit]; }
#endif
void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) {
int j = 0;
if (real_c == 8) {
@ -160,6 +211,74 @@ void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, i
}
}
void InputTransform4x4StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
int dst_row_step) {
for (int l = 0; l < 4; ++l) {
const float16_t *src_ptr = src_data + l * 4 * src_step;
float16_t *dst_ptr = dst_data + l * dst_row_step;
float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step);
float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step);
float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step);
float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step);
float16x8_t m0 = vsubq_f16(s0, s2);
float16x8_t m1 = vaddq_f16(s1, s2);
float16x8_t m2 = vsubq_f16(s2, s1);
float16x8_t m3 = vsubq_f16(s3, s1);
vst1q_f16(dst_ptr + 0 * dst_step, m0);
vst1q_f16(dst_ptr + 1 * dst_step, m1);
vst1q_f16(dst_ptr + 2 * dst_step, m2);
vst1q_f16(dst_ptr + 3 * dst_step, m3);
}
}
#ifdef ENABLE_ARM64
void InputTransform4x4Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile,
int src_point_stride) {
LOAD_LINE_DATA_FP16(0);
LOAD_LINE_DATA_FP16(1);
LOAD_LINE_DATA_FP16(2);
LOAD_LINE_DATA_FP16(3);
float16x8_t m0 = vsubq_f16(s00, s20);
float16x8_t m1 = vsubq_f16(s01, s21);
vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1);
m0 = vaddq_f16(s10, s20);
m1 = vaddq_f16(s11, s21);
vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1);
m0 = vsubq_f16(s20, s10);
m1 = vsubq_f16(s21, s11);
vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1);
m0 = vsubq_f16(s30, s10);
m1 = vsubq_f16(s31, s11);
vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1);
}
void InputTransform4x4Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) {
int block_tile = 16;
int pack_tile = src_step;
int src_point_stride = block_tile * pack_tile;
for (int l = 0; l < 4; ++l) {
float16_t *src_ptr = src_data + l * C8NUM * block_tile;
TRANSPOSE_16x8;
}
for (int c = 0; c < real_c; ++c) {
float16_t *src_ptr = src_data + c * block_tile;
float16_t *dst_ptr = dst_data + c * block_tile;
InputTransform4x4Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride);
}
}
#endif
void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) {
int j = 0;
if (real_c == 8) {
@ -272,6 +391,95 @@ void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, i
}
}
void InputTransform6x6StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
int dst_row_step) {
for (int l = 0; l < 6; ++l) {
const float16_t *src_ptr = src_data + l * 6 * src_step;
float16_t *dst_ptr = dst_data + l * dst_row_step;
float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step);
float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step);
float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step);
float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step);
float16x8_t s4 = vld1q_f16(src_ptr + 4 * src_step);
float16x8_t s5 = vld1q_f16(src_ptr + 5 * src_step);
float16x8_t tmp1 = vsubq_f16(s3, s1);
float16x8_t tmp2 = vsubq_f16(s4, s2);
float16x8_t m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s0, 4), vmulq_n_f16(s2, 5)), s4);
float16x8_t m1 = vaddq_f16(vmulq_n_f16(vaddq_f16(s1, s2), -4), vaddq_f16(s3, s4));
float16x8_t m2 = vaddq_f16(vmulq_n_f16(vsubq_f16(s1, s2), 4), vsubq_f16(s4, s3));
float16x8_t m3 = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2);
float16x8_t m4 = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2);
float16x8_t m5 = vaddq_f16(vsubq_f16(vmulq_n_f16(s1, 4), vmulq_n_f16(s3, 5)), s5);
vst1q_f16(dst_ptr + 0 * dst_step, m0);
vst1q_f16(dst_ptr + 1 * dst_step, m1);
vst1q_f16(dst_ptr + 2 * dst_step, m2);
vst1q_f16(dst_ptr + 3 * dst_step, m3);
vst1q_f16(dst_ptr + 4 * dst_step, m4);
vst1q_f16(dst_ptr + 5 * dst_step, m5);
}
}
#ifdef ENABLE_ARM64
void InputTransform6x6Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile,
int src_point_stride) {
LOAD_LINE_DATA_FP16(0);
LOAD_LINE_DATA_FP16(1);
LOAD_LINE_DATA_FP16(2);
LOAD_LINE_DATA_FP16(3);
LOAD_LINE_DATA_FP16(4);
LOAD_LINE_DATA_FP16(5);
float16x8_t m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s00, 4), vmulq_n_f16(s20, 5)), s40);
float16x8_t m1 = vaddq_f16(vsubq_f16(vmulq_n_f16(s01, 4), vmulq_n_f16(s21, 5)), s41);
vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1);
m0 = vaddq_f16(vmulq_n_f16(vaddq_f16(s10, s20), -4), vaddq_f16(s30, s40));
m1 = vaddq_f16(vmulq_n_f16(vaddq_f16(s11, s21), -4), vaddq_f16(s31, s41));
vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1);
m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s10, s20), 4), vsubq_f16(s40, s30));
m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s11, s21), 4), vsubq_f16(s41, s31));
vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1);
m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s30, s10), 2), vsubq_f16(s40, s20));
m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s31, s11), 2), vsubq_f16(s41, s21));
vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1);
m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s30, s10), -2), vsubq_f16(s40, s20));
m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s31, s11), -2), vsubq_f16(s41, s21));
vst1q_f16(dst_ptr + 4 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 4 * dst_step + 1 * pack_tile, m1);
m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s10, 4), vmulq_n_f16(s30, 5)), s50);
m1 = vaddq_f16(vsubq_f16(vmulq_n_f16(s11, 4), vmulq_n_f16(s31, 5)), s51);
vst1q_f16(dst_ptr + 5 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 5 * dst_step + 1 * pack_tile, m1);
}
void InputTransform6x6Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) {
int block_tile = 16;
int pack_tile = src_step;
int src_point_stride = block_tile * pack_tile;
for (int l = 0; l < 6; ++l) {
float16_t *src_ptr = src_data + l * C8NUM * block_tile;
TRANSPOSE_16x8;
}
for (int c = 0; c < real_c; ++c) {
float16_t *src_ptr = src_data + c * block_tile;
float16_t *dst_ptr = dst_data + c * block_tile;
InputTransform6x6Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride);
}
}
#endif
void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) {
int j = 0;
if (real_c == 8) {
@ -429,6 +637,133 @@ void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, i
}
}
void InputTransform8x8StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
int dst_row_step) {
for (int l = 0; l < 8; ++l) {
const float16_t *src_ptr = src_data + l * 8 * src_step;
float16_t *dst_ptr = dst_data + l * dst_row_step;
float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step);
float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step);
float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step);
float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step);
float16x8_t s4 = vld1q_f16(src_ptr + 4 * src_step);
float16x8_t s5 = vld1q_f16(src_ptr + 5 * src_step);
float16x8_t s6 = vld1q_f16(src_ptr + 6 * src_step);
float16x8_t s7 = vld1q_f16(src_ptr + 7 * src_step);
float16x8_t m0 =
vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s0, 0.5625), vmulq_n_f16(s2, 3.0625)), vmulq_n_f16(s4, 3.5)), s6);
float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(s1, 1.125), vmulq_n_f16(s5, 0.5));
float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(s2, 2.25), vmulq_n_f16(s4, 3.25));
float16x8_t m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s3, 1.625)), s6);
float16x8_t m2 = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s3, 1.625)), s6);
tmp1 = vaddq_f16(vmulq_n_f16(s1, 0.5625), s5);
tmp2 = vsubq_f16(vmulq_n_f16(s2, 0.5625), vmulq_n_f16(s4, 2.5));
float16x8_t m3 = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s3, 2.5)), s6);
float16x8_t m4 = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s3, 2.5)), s6);
tmp1 = vaddq_f16(vmulq_n_f16(s1, 0.375), vmulq_n_f16(s5, 1.5));
tmp2 = vsubq_f16(vmulq_n_f16(s2, 0.25), vmulq_n_f16(s4, 1.25));
float16x8_t m5 = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s3, 1.875)), s6);
float16x8_t m6 = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s3, 1.875)), s6);
float16x8_t m7 =
vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s1, -0.5625), vmulq_n_f16(s3, 3.0625)), vmulq_n_f16(s5, 3.5)), s7);
vst1q_f16(dst_ptr + 0 * dst_step, m0);
vst1q_f16(dst_ptr + 1 * dst_step, m1);
vst1q_f16(dst_ptr + 2 * dst_step, m2);
vst1q_f16(dst_ptr + 3 * dst_step, m3);
vst1q_f16(dst_ptr + 4 * dst_step, m4);
vst1q_f16(dst_ptr + 5 * dst_step, m5);
vst1q_f16(dst_ptr + 6 * dst_step, m6);
vst1q_f16(dst_ptr + 7 * dst_step, m7);
}
}
#ifdef ENABLE_ARM64
void InputTransform8x8Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile,
int src_point_stride) {
LOAD_LINE_DATA_FP16(0);
LOAD_LINE_DATA_FP16(1);
LOAD_LINE_DATA_FP16(2);
LOAD_LINE_DATA_FP16(3);
LOAD_LINE_DATA_FP16(4);
LOAD_LINE_DATA_FP16(5);
LOAD_LINE_DATA_FP16(6);
LOAD_LINE_DATA_FP16(7);
float16x8_t m0 =
vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s00, 0.5625), vmulq_n_f16(s20, 3.0625)), vmulq_n_f16(s40, 3.5)), s60);
float16x8_t m1 =
vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s01, 0.5625), vmulq_n_f16(s21, 3.0625)), vmulq_n_f16(s41, 3.5)), s61);
vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1);
float16x8_t tmp10 = vaddq_f16(vmulq_n_f16(s10, 1.125), vmulq_n_f16(s50, 0.5));
float16x8_t tmp11 = vaddq_f16(vmulq_n_f16(s11, 1.125), vmulq_n_f16(s51, 0.5));
float16x8_t tmp20 = vsubq_f16(vmulq_n_f16(s20, 2.25), vmulq_n_f16(s40, 3.25));
float16x8_t tmp21 = vsubq_f16(vmulq_n_f16(s21, 2.25), vmulq_n_f16(s41, 3.25));
m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 1.625)), s60);
m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 1.625)), s61);
vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1);
m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 1.625)), s60);
m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 1.625)), s61);
vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1);
tmp10 = vaddq_f16(vmulq_n_f16(s10, 0.5625), s50);
tmp11 = vaddq_f16(vmulq_n_f16(s11, 0.5625), s51);
tmp20 = vsubq_f16(vmulq_n_f16(s20, 0.5625), vmulq_n_f16(s40, 2.5));
tmp21 = vsubq_f16(vmulq_n_f16(s21, 0.5625), vmulq_n_f16(s41, 2.5));
m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 2.5)), s60);
m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 2.5)), s61);
vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1);
m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 2.5)), s60);
m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 2.5)), s61);
vst1q_f16(dst_ptr + 4 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 4 * dst_step + 1 * pack_tile, m1);
tmp10 = vaddq_f16(vmulq_n_f16(s10, 0.375), vmulq_n_f16(s50, 1.5));
tmp11 = vaddq_f16(vmulq_n_f16(s11, 0.375), vmulq_n_f16(s51, 1.5));
tmp20 = vsubq_f16(vmulq_n_f16(s20, 0.25), vmulq_n_f16(s40, 1.25));
tmp21 = vsubq_f16(vmulq_n_f16(s21, 0.25), vmulq_n_f16(s41, 1.25));
m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 1.875)), s60);
m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 1.875)), s61);
vst1q_f16(dst_ptr + 5 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 5 * dst_step + 1 * pack_tile, m1);
m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 1.875)), s60);
m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 1.875)), s61);
vst1q_f16(dst_ptr + 6 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 6 * dst_step + 1 * pack_tile, m1);
m0 = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s10, -0.5625), vmulq_n_f16(s30, 3.0625)), vmulq_n_f16(s50, 3.5)), s70);
m1 = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s11, -0.5625), vmulq_n_f16(s31, 3.0625)), vmulq_n_f16(s51, 3.5)), s71);
vst1q_f16(dst_ptr + 7 * dst_step + 0 * pack_tile, m0);
vst1q_f16(dst_ptr + 7 * dst_step + 1 * pack_tile, m1);
}
void InputTransform8x8Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) {
int block_tile = 16;
int pack_tile = src_step;
int src_point_stride = block_tile * pack_tile;
for (int l = 0; l < 8; ++l) {
float16_t *src_ptr = src_data + l * C8NUM * block_tile;
TRANSPOSE_16x8;
}
for (int c = 0; c < real_c; ++c) {
float16_t *src_ptr = src_data + c * block_tile;
float16_t *dst_ptr = dst_data + c * block_tile;
InputTransform8x8Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride);
}
}
#endif
OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type) {
if (input_unit == 4 && output_unit < 4) {
if (act_type == ActType_Relu) {

View File

@ -29,9 +29,22 @@ extern "C" {
typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
int real_c);
typedef void (*InputTransStepFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
int dst_row_step);
typedef void (*InputTransPackFp16Func)(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
int real_c);
typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c);
typedef struct TransFp16FuncList {
InputTransFp16Func in_func_;
InputTransStepFp16Func in_step_func_;
InputTransPackFp16Func in_pack_func_;
OutputTransFp16Func out_func_;
} TransFp16FuncList;
#define Load16DataFp16 \
src[0] = vld1q_f16(src_data + 0 * src_step); \
src[1] = vld1q_f16(src_data + 1 * src_step); \
@ -276,14 +289,77 @@ typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_da
src[62] = vld1_f16(src_data + 62 * src_step); \
src[63] = vld1_f16(src_data + 63 * src_step);
#define LOAD_LINE_DATA_FP16(line) \
float16x8_t s##line##0 = vld1q_f16(src_ptr + line * src_point_stride + 0 * pack_tile); \
float16x8_t s##line##1 = vld1q_f16(src_ptr + line * src_point_stride + 1 * pack_tile);
#define TRANSPOSE_16x8 \
float16x8_t s0 = vld1q_f16(src_ptr + 0 * pack_tile); \
float16x8_t s2 = vld1q_f16(src_ptr + 1 * pack_tile); \
float16x8_t s4 = vld1q_f16(src_ptr + 2 * pack_tile); \
float16x8_t s6 = vld1q_f16(src_ptr + 3 * pack_tile); \
float16x8_t s8 = vld1q_f16(src_ptr + 4 * pack_tile); \
float16x8_t s10 = vld1q_f16(src_ptr + 5 * pack_tile); \
float16x8_t s12 = vld1q_f16(src_ptr + 6 * pack_tile); \
float16x8_t s14 = vld1q_f16(src_ptr + 7 * pack_tile); \
float16x8_t s1 = vld1q_f16(src_ptr + 8 * pack_tile); \
float16x8_t s3 = vld1q_f16(src_ptr + 9 * pack_tile); \
float16x8_t s5 = vld1q_f16(src_ptr + 10 * pack_tile); \
float16x8_t s7 = vld1q_f16(src_ptr + 11 * pack_tile); \
float16x8_t s9 = vld1q_f16(src_ptr + 12 * pack_tile); \
float16x8_t s11 = vld1q_f16(src_ptr + 13 * pack_tile); \
float16x8_t s13 = vld1q_f16(src_ptr + 14 * pack_tile); \
float16x8_t s15 = vld1q_f16(src_ptr + 15 * pack_tile); \
transpose8(&s0, &s2, &s4, &s6, &s8, &s10, &s12, &s14); \
transpose8(&s1, &s3, &s5, &s7, &s9, &s11, &s13, &s15); \
vst1q_f16(src_ptr + 0 * pack_tile, s0); \
vst1q_f16(src_ptr + 1 * pack_tile, s1); \
vst1q_f16(src_ptr + 2 * pack_tile, s2); \
vst1q_f16(src_ptr + 3 * pack_tile, s3); \
vst1q_f16(src_ptr + 4 * pack_tile, s4); \
vst1q_f16(src_ptr + 5 * pack_tile, s5); \
vst1q_f16(src_ptr + 6 * pack_tile, s6); \
vst1q_f16(src_ptr + 7 * pack_tile, s7); \
vst1q_f16(src_ptr + 8 * pack_tile, s8); \
vst1q_f16(src_ptr + 9 * pack_tile, s9); \
vst1q_f16(src_ptr + 10 * pack_tile, s10); \
vst1q_f16(src_ptr + 11 * pack_tile, s11); \
vst1q_f16(src_ptr + 12 * pack_tile, s12); \
vst1q_f16(src_ptr + 13 * pack_tile, s13); \
vst1q_f16(src_ptr + 14 * pack_tile, s14); \
vst1q_f16(src_ptr + 15 * pack_tile, s15);
InputTransFp16Func GetInputTransFp16Func(int input_unit);
#ifdef ENABLE_ARM64
InputTransStepFp16Func GetInputTransStepFp16Func(int input_unit);
InputTransPackFp16Func GetInputTransPackFp16Func(int input_unit);
#endif
void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
void InputTransform4x4StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
int dst_row_step);
void InputTransform6x6StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
int dst_row_step);
void InputTransform8x8StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
int dst_row_step);
#ifdef ENABLE_ARM64
void InputTransform4x4Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
void InputTransform6x6Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
void InputTransform8x8Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
#endif
OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type);
#define Store4DataFp16 \

View File

@ -23,11 +23,12 @@
// fp32 conv winograd
void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, const ConvParameter *conv_param,
InputTransFunc in_func, OutputTransFunc out_func) {
TransFuncList trans_func) {
if (conv_param->output_unit_ == 0) {
return;
}
int in_channel = conv_param->input_channel_;
int input_unit = conv_param->input_unit_;
int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_);
int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_);
int output_count = out_w_block * out_h_block;
@ -35,26 +36,19 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
int output_tile_count = UP_DIV(output_count, tile_num);
#ifdef ENABLE_AVX
const int col_tile = C16NUM;
const int tmp_data_tile = C8NUM;
const int channel_pack_tile = C8NUM;
#else
const int col_tile = C8NUM;
const int tmp_data_tile = C4NUM;
const int channel_pack_tile = C4NUM;
#endif
int oc_tile = UP_DIV(conv_param->output_channel_, col_tile);
int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_;
if (input_unit_square < conv_param->input_unit_) {
return;
}
int input_unit_square = input_unit * input_unit;
float *trans_input = buffer_list[0];
float *gemm_out = buffer_list[1];
float *tmp_data = buffer_list[2];
float *col_buffer = buffer_list[3];
int trans_input_offset = tile_num * input_unit_square * in_channel;
int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
int tmp_data_offset = input_unit_square * tmp_data_tile;
int col_buffer_offset = tile_num * in_channel;
float *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel;
float *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM;
float *tmp_data = buffer_list[2] + task_id * input_unit_square * channel_pack_tile;
float *col_buffer = buffer_list[3] + task_id * tile_num * in_channel;
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)
for (int b = 0; b < conv_param->input_batch_; b++) {
@ -67,37 +61,75 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
if (cal_num <= 0) {
return;
}
WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
in_func);
// step 3 : gemm
float *src_ptr = trans_input + task_id * trans_input_offset;
float *dst_ptr = gemm_out + task_id * gemm_out_offset;
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
#ifdef ENABLE_ARM64
// Optimize input transform. Only valid for arm64, the tile num is 12, the channel_tile is 4.
// For arm32, the tile_num is 4.
// For x86_sse, the tile_num is 4, the channel_tile is 4.
// For avx, the tile_num is 6, the channel_tile is 8.
// N = input_unit, M = tile_num
// The function(InputTransformNxNStep, InputTransform4x4PackM) needs to be rewritten.
bool fused_pack =
(cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL);
if (fused_pack) {
float *opt_trans_input =
buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, channel_pack_tile);
WinogradInputTransformOptStep(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num, out_tile_index,
out_w_block, conv_param, trans_func.in_step_func_);
for (int w_index = 0; w_index < input_unit; w_index++) {
float *src_w = opt_trans_input + w_index * input_unit * tile_num * channel_pack_tile;
for (int c = 0; c < UP_DIV(in_channel, channel_pack_tile); c++) {
int real_c = in_channel - c * channel_pack_tile;
real_c = real_c > channel_pack_tile ? channel_pack_tile : real_c;
float *src_c = src_w + c * input_unit_square * tile_num * channel_pack_tile;
float *dst_c = trans_input + c * tile_num * channel_pack_tile;
trans_func.in_pack_func_(src_c, dst_c, channel_pack_tile, in_channel * tile_num, real_c);
}
for (int h_index = 0; h_index < input_unit; h_index++) {
const float *gemm_input = trans_input + h_index * tile_num * in_channel;
int point_index = h_index * input_unit + w_index;
const float *gemm_weight = trans_weight + point_index * in_channel * oc_tile * col_tile;
MatMulOpt(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num,
oc8 * C8NUM, input_unit_square, OutType_TileC8);
}
}
} else {
#endif
WinogradInputTransform(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index,
out_w_block, conv_param, trans_func.in_func_);
// step 3 : gemm
float *src_ptr = trans_input;
float *dst_ptr = gemm_out;
float *tmp_col_ptr = col_buffer;
for (int i = 0; i < input_unit_square; ++i) {
#ifdef ENABLE_AVX
RowMajor2Col6Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
RowMajor2Col6Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
#else
RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
#endif
MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc_tile * col_tile, dst_ptr + i * C8NUM, NULL, 0,
in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2);
MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc_tile * col_tile, dst_ptr + i * C8NUM, NULL, 0,
in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2);
}
#ifdef ENABLE_ARM64
}
#endif
// step 4 : output transform
float *output_ptr = output_data + out_batch_offset;
if (conv_param->out_format_ != NNACL_NC4HW4) { // nc4hw4
WinogradOutputNHWCTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
out_func);
WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
trans_func.out_func_);
} else {
#if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
WinogradOutputNC4HW4Transform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
out_func);
WinogradOutputNC4HW4Transform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
trans_func.out_func_);
#else
WinogradOutputNHWCTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
out_func);
WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
trans_func.out_func_);
#endif
}
}

View File

@ -36,7 +36,7 @@ extern "C" {
// fp32 convolution winograd
void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, const ConvParameter *conv_param,
InputTransFunc in_func, OutputTransFunc out_func);
TransFuncList trans_func);
#ifdef __cplusplus
}
#endif

View File

@ -17,6 +17,59 @@
#include "nnacl/fp32/winograd_transform.h"
#include "nnacl/op_base.h"
void PrepareTransInput(const float *src_data, float *dst_data, int interval_x_s, int interval_x_e, int interval_y_s,
int interval_y_e, int real_c, const ConvParameter *conv_param) {
int input_unit = conv_param->input_unit_;
int in_channel = conv_param->input_channel_;
int input_w = conv_param->input_w_;
#ifdef ENABLE_AVX
int channel_tile = C8NUM;
#else
int channel_tile = C4NUM;
#endif
// clear tmp buffer
if (interval_x_e - interval_x_s != input_unit || interval_y_e - interval_y_s != input_unit) {
memset(dst_data, 0, input_unit * input_unit * channel_tile * (int)(sizeof(float)));
}
// get real input block with padding
if (real_c == channel_tile) {
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = (interval * input_w + interval_x_s) * in_channel;
int dst_y_offset = interval * input_unit * channel_tile + interval_x_s * channel_tile;
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * in_channel;
int dst_x_offset = dst_y_offset + j * channel_tile;
const float *src_addr = src_data + src_x_offset;
float *dst_addr = dst_data + dst_x_offset;
#ifdef ENABLE_AVX
MS_ST256_F32(dst_addr, MS_LD256_F32(src_addr));
#elif defined(ENABLE_ARM) || defined(ENABLE_SSE)
MS_STQ_F32(dst_addr, MS_LDQ_F32(src_addr));
#else
for (int k = 0; k < channel_tile; k++) {
dst_addr[k] = src_addr[k];
}
#endif
} // interval x loop
} // interval y loop
} else {
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = (interval * input_w + interval_x_s) * in_channel;
int dst_y_offset = interval * input_unit * channel_tile + interval_x_s * channel_tile;
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * in_channel;
int dst_x_offset = dst_y_offset + j * channel_tile;
const float *src_addr = src_data + src_x_offset;
float *dst_addr = dst_data + dst_x_offset;
for (int k = 0; k < real_c; k++) {
dst_addr[k] = src_addr[k];
}
} // interval x loop
} // interval y loop
}
}
// fp32 conv winograd
void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num,
int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
@ -25,11 +78,11 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
int output_unit = conv_param->output_unit_;
int in_channel = conv_param->input_channel_;
#ifdef ENABLE_AVX
int tile = C8NUM;
int channel_tile = C8NUM;
#else
int tile = C4NUM;
int channel_tile = C4NUM;
#endif
int ic4 = UP_DIV(in_channel, tile);
int ic4 = UP_DIV(in_channel, channel_tile);
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int input_h = conv_param->input_h_;
@ -49,54 +102,61 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s);
int dst_plane_offset = c * in_channel;
for (int ic = 0; ic < ic4; ic++) {
// clear tmp buffer
memset(tmp_data, 0, input_unit * input_unit * tile * (int)(sizeof(float)));
int real_c = in_channel - ic * channel_tile;
real_c = real_c > channel_tile ? channel_tile : real_c;
const float *src_data = input_data + src_plane_offset + ic * channel_tile;
PrepareTransInput(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, conv_param);
int real_c = in_channel - ic * tile;
real_c = real_c > tile ? tile : real_c;
int src_ic4_offset = src_plane_offset + ic * tile;
// get real input block with padding
if (real_c == tile) {
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * in_channel;
int dst_y_offset = interval * input_unit * tile + interval_x_s * tile;
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * in_channel;
int dst_x_offset = dst_y_offset + j * tile;
float *src_addr = (float *)(input_data) + src_x_offset;
float *dst_addr = tmp_data + dst_x_offset;
#ifdef ENABLE_AVX
MS_ST256_F32(dst_addr, MS_LD256_F32(src_addr));
#elif defined(ENABLE_ARM) || defined(ENABLE_SSE)
MS_STQ_F32(dst_addr, MS_LDQ_F32(src_addr));
#else
for (int k = 0; k < tile; k++) {
dst_addr[k] = src_addr[k];
}
#endif
} // interval x loop
} // interval y loop
} else {
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * in_channel;
int dst_y_offset = interval * input_unit * tile + interval_x_s * tile;
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * in_channel;
int dst_x_offset = dst_y_offset + j * tile;
float *src_addr = (float *)(input_data) + src_x_offset;
float *dst_addr = tmp_data + dst_x_offset;
for (int k = 0; k < real_c; k++) {
dst_addr[k] = src_addr[k];
}
} // interval x loop
} // interval y loop
}
// input transform
const int tile_num = C12NUM;
int dst_ic4_offset = dst_plane_offset + ic * tile;
int dst_ic4_offset = dst_plane_offset + ic * channel_tile;
int dst_step = tile_num * in_channel;
float *trans_input_ptr = trans_input + dst_ic4_offset;
func(tmp_data, trans_input_ptr, tile, dst_step, real_c);
func(tmp_data, trans_input_ptr, channel_tile, dst_step, real_c);
}
out_tile_index++;
} // cal_tile_num loop
}
// Only support arm64
void WinogradInputTransformOptStep(const float *input_data, float *trans_input, float *tmp_data, int cal_num,
int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
InputTransStepFunc func) {
int input_unit = conv_param->input_unit_;
int output_unit = conv_param->output_unit_;
int in_channel = conv_param->input_channel_;
int channel_tile = C4NUM;
int ic4 = UP_DIV(in_channel, channel_tile);
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int input_h = conv_param->input_h_;
int input_w = conv_param->input_w_;
NNACL_CHECK_ZERO_RETURN(out_w_block_num);
for (int c = 0; c < cal_num; c++) { // actual tiled number
int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w;
int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h;
int interval_x_s = src_x_s > 0 ? 0 : -src_x_s;
int interval_y_s = src_y_s > 0 ? 0 : -src_y_s;
int src_x_e = src_x_s + input_unit;
int src_y_e = src_y_s + input_unit;
int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s);
int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s);
int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s);
int dst_plane_offset = c * channel_tile;
for (int ic = 0; ic < ic4; ic++) {
int real_c = in_channel - ic * channel_tile;
real_c = real_c > channel_tile ? channel_tile : real_c;
const float *src_data = input_data + src_plane_offset + ic * channel_tile;
PrepareTransInput(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, conv_param);
// input transform
const int block_tile = C12NUM;
int dst_ic8_offset = dst_plane_offset + ic * block_tile * input_unit * input_unit * channel_tile;
size_t dst_step = input_unit * block_tile * channel_tile;
float *trans_input_ptr = trans_input + dst_ic8_offset;
func(tmp_data, trans_input_ptr, channel_tile, dst_step, block_tile * channel_tile);
}
out_tile_index++;
} // cal_tile_num loop

View File

@ -32,6 +32,10 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
InputTransFunc func);
void WinogradInputTransformOptStep(const float *input_data, float *trans_input, float *tmp_data, int cal_num,
int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
InputTransStepFunc func);
void WinogradOutputNHWCTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num,
int out_tile_index, int output_unit_num, const ConvParameter *conv_param,
OutputTransFunc func);

View File

@ -20,6 +20,19 @@
#include "nnacl/base/conv_common_base.h"
#include "nnacl/errorcode.h"
#ifdef ENABLE_ARM64
void transpose4(MS_FLOAT32X4 *s0, MS_FLOAT32X4 *s1, MS_FLOAT32X4 *s2, MS_FLOAT32X4 *s3) {
float64x2_t m0 = (float64x2_t)(vtrn1q_f32(*s0, *s1));
float64x2_t m1 = (float64x2_t)(vtrn2q_f32(*s0, *s1));
float64x2_t m2 = (float64x2_t)(vtrn1q_f32(*s2, *s3));
float64x2_t m3 = (float64x2_t)(vtrn2q_f32(*s2, *s3));
*s0 = (float32x4_t)(vtrn1q_f64(m0, m2));
*s2 = (float32x4_t)(vtrn2q_f64(m0, m2));
*s1 = (float32x4_t)(vtrn1q_f64(m1, m3));
*s3 = (float32x4_t)(vtrn2q_f64(m1, m3));
}
#endif
#ifdef ENABLE_AVX
static InputTransFunc InputTransFuncList[] = {
NULL, NULL, NULL, NULL, InputTransform4x4AvxUnit, NULL, InputTransform6x6AvxUnit, NULL, InputTransform8x8AvxUnit};
@ -55,6 +68,18 @@ static OutputTransFunc OutputTransFuncList[] = {
InputTransFunc GetInputTransFunc(int input_unit) { return InputTransFuncList[input_unit]; }
#ifdef ENABLE_ARM64
static InputTransStepFunc InputTransStepFuncList[] = {
NULL, NULL, NULL, NULL, InputTransform4x4Step, NULL, InputTransform6x6Step, NULL, InputTransform8x8Step};
static InputTransPackFunc InputTransPackFuncList[] = {
NULL, NULL, NULL, NULL, InputTransform4x4Pack12, NULL, InputTransform6x6Pack12, NULL, InputTransform8x8Pack12};
InputTransStepFunc GetInputTransStepFunc(int input_unit) { return InputTransStepFuncList[input_unit]; }
InputTransPackFunc GetInputTransPackFunc(int input_unit) { return InputTransPackFuncList[input_unit]; }
#endif
void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
if (real_c == 4) {
@ -138,6 +163,136 @@ void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step,
#endif
}
void InputTransform4x4Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) {
#ifdef ENABLE_ARM64
for (int l = 0; l < 4; ++l) {
const float *src_ptr = src_data + l * 4 * src_step;
float *dst_ptr = dst_data + l * dst_row_step;
MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step);
MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step);
MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step);
MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step);
MS_FLOAT32X4 m0 = MS_SUBQ_F32(s0, s2);
MS_FLOAT32X4 m1 = MS_ADDQ_F32(s1, s2);
MS_FLOAT32X4 m2 = MS_SUBQ_F32(s2, s1);
MS_FLOAT32X4 m3 = MS_SUBQ_F32(s3, s1);
MS_STQ_F32(dst_ptr + 0 * dst_step, m0);
MS_STQ_F32(dst_ptr + 1 * dst_step, m1);
MS_STQ_F32(dst_ptr + 2 * dst_step, m2);
MS_STQ_F32(dst_ptr + 3 * dst_step, m3);
}
#else
float src[4];
float m[4];
for (int i = 0; i < C4NUM; ++i) {
for (int l = 0; l < 4; ++l) {
for (int w = 0; w < 4; ++w) {
int tmp_index = l * 4 + w;
src[w] = src_data[i + tmp_index * src_step];
}
m[0] = src[0] - src[2];
m[1] = src[1] + src[2];
m[2] = src[2] - src[1];
m[3] = src[3] - src[1];
float *dst = dst_data + l * dst_row_step;
for (int w = 0; w < 4; ++w) {
dst[i + w * dst_step] = m[w];
}
}
}
#endif
}
#ifdef ENABLE_ARM64
void InputTransform4x4Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) {
LOAD_LINE_DATA(0);
LOAD_LINE_DATA(1);
LOAD_LINE_DATA(2);
LOAD_LINE_DATA(3);
MS_FLOAT32X4 m0 = MS_SUBQ_F32(s00, s20);
MS_FLOAT32X4 m1 = MS_SUBQ_F32(s01, s21);
MS_FLOAT32X4 m2 = MS_SUBQ_F32(s02, s22);
MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2);
m0 = MS_ADDQ_F32(s10, s20);
m1 = MS_ADDQ_F32(s11, s21);
m2 = MS_ADDQ_F32(s12, s22);
MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2);
m0 = MS_SUBQ_F32(s20, s10);
m1 = MS_SUBQ_F32(s21, s11);
m2 = MS_SUBQ_F32(s22, s12);
MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2);
m0 = MS_SUBQ_F32(s30, s10);
m1 = MS_SUBQ_F32(s31, s11);
m2 = MS_SUBQ_F32(s32, s12);
MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2);
}
#endif
void InputTransform4x4Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
int block_tile = 12;
int pack_tile = src_step;
int src_point_stride = block_tile * pack_tile;
#ifdef ENABLE_ARM64
for (int l = 0; l < 4; ++l) {
float *src_ptr = src_data + l * C4NUM * block_tile;
TRANSPOSE_12x4;
}
for (int c = 0; c < real_c; ++c) {
float *src_ptr = src_data + c * block_tile;
float *dst_ptr = dst_data + c * block_tile;
InputTransform4x4Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride);
}
#else
for (int l = 0; l < 4; ++l) {
float *src = src_data + l * pack_tile * block_tile;
// 12 * 4 -> 4 * 12
float tmp_mat[pack_tile][block_tile];
for (int i = 0; i < block_tile; ++i) {
for (int j = 0; j < pack_tile; ++j) {
tmp_mat[j][i] = src[i * pack_tile + j];
}
}
memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float));
}
float src[4];
float m[4];
for (int c = 0; c < real_c; ++c) {
for (int i = 0; i < block_tile; ++i) {
int tmp_index = c * block_tile + i;
for (int w = 0; w < 4; ++w) {
src[w] = src_data[tmp_index + w * src_point_stride];
}
m[0] = src[0] - src[2];
m[1] = src[1] + src[2];
m[2] = src[2] - src[1];
m[3] = src[3] - src[1];
for (int w = 0; w < 4; ++w) {
dst_data[tmp_index + w * dst_step] = m[w];
}
}
}
#endif
}
void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
if (real_c == 4) {
@ -217,6 +372,169 @@ void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step,
#endif
}
void InputTransform6x6Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) {
#ifdef ENABLE_ARM64
for (int l = 0; l < 6; ++l) {
const float *src_ptr = src_data + l * 6 * src_step;
float *dst_ptr = dst_data + l * dst_row_step;
MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step);
MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step);
MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step);
MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step);
MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 4 * src_step);
MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 5 * src_step);
MS_FLOAT32X4 tmp1 = MS_SUBQ_F32(s3, s1);
MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(s4, s2);
MS_FLOAT32X4 m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s0, 4), MS_MULQ_N_F32(s2, 5)), s4);
MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s1, s2), -4), MS_ADDQ_F32(s3, s4));
MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s1, s2), 4), MS_SUBQ_F32(s4, s3));
MS_FLOAT32X4 m3 = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 2), tmp2);
MS_FLOAT32X4 m4 = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, -2), tmp2);
MS_FLOAT32X4 m5 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s1, 4), MS_MULQ_N_F32(s3, 5)), s5);
MS_STQ_F32(dst_ptr + 0 * dst_step, m0);
MS_STQ_F32(dst_ptr + 1 * dst_step, m1);
MS_STQ_F32(dst_ptr + 2 * dst_step, m2);
MS_STQ_F32(dst_ptr + 3 * dst_step, m3);
MS_STQ_F32(dst_ptr + 4 * dst_step, m4);
MS_STQ_F32(dst_ptr + 5 * dst_step, m5);
}
#else
float src[6];
float m[6];
for (int i = 0; i < C4NUM; ++i) {
for (int l = 0; l < 6; ++l) {
for (int w = 0; w < 6; ++w) {
int tmp_index = l * 6 + w;
src[w] = src_data[i + tmp_index * src_step];
}
float tmp1 = src[3] - src[1];
float tmp2 = src[4] - src[2];
m[0] = 4 * src[0] - 5 * src[2] + src[4];
m[1] = -4 * (src[1] + src[2]) + (src[3] + src[4]);
m[2] = 4 * (src[1] - src[2]) + (src[4] - src[3]);
m[3] = 2 * tmp1 + tmp2;
m[4] = -2 * tmp1 + tmp2;
m[5] = 4 * src[1] - 5 * src[3] + src[5];
float *dst = dst_data + l * dst_row_step;
for (int w = 0; w < 6; ++w) {
dst[i + w * dst_step] = m[w];
}
}
}
#endif
}
#ifdef ENABLE_ARM64
void InputTransform6x6Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) {
LOAD_LINE_DATA(0);
LOAD_LINE_DATA(1);
LOAD_LINE_DATA(2);
LOAD_LINE_DATA(3);
LOAD_LINE_DATA(4);
LOAD_LINE_DATA(5);
MS_FLOAT32X4 m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s00, 4), MS_MULQ_N_F32(s20, 5)), s40);
MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s01, 4), MS_MULQ_N_F32(s21, 5)), s41);
MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s02, 4), MS_MULQ_N_F32(s22, 5)), s42);
MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2);
m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s10, s20), -4), MS_ADDQ_F32(s30, s40));
m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s11, s21), -4), MS_ADDQ_F32(s31, s41));
m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s12, s22), -4), MS_ADDQ_F32(s32, s42));
MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2);
m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s10, s20), 4), MS_SUBQ_F32(s40, s30));
m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s11, s21), 4), MS_SUBQ_F32(s41, s31));
m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s12, s22), 4), MS_SUBQ_F32(s42, s32));
MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2);
m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s30, s10), 2), MS_SUBQ_F32(s40, s20));
m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s31, s11), 2), MS_SUBQ_F32(s41, s21));
m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s32, s12), 2), MS_SUBQ_F32(s42, s22));
MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2);
m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s30, s10), -2), MS_SUBQ_F32(s40, s20));
m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s31, s11), -2), MS_SUBQ_F32(s41, s21));
m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s32, s12), -2), MS_SUBQ_F32(s42, s22));
MS_STQ_F32(dst_ptr + 4 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 4 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 4 * dst_step + 2 * pack_tile, m2);
m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s10, 4), MS_MULQ_N_F32(s30, 5)), s50);
m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s11, 4), MS_MULQ_N_F32(s31, 5)), s51);
m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s12, 4), MS_MULQ_N_F32(s32, 5)), s52);
MS_STQ_F32(dst_ptr + 5 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 5 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 5 * dst_step + 2 * pack_tile, m2);
}
#endif
void InputTransform6x6Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
int block_tile = 12;
int pack_tile = src_step;
int src_point_stride = block_tile * pack_tile;
#ifdef ENABLE_ARM64
for (int l = 0; l < 6; ++l) {
float *src_ptr = src_data + l * C4NUM * block_tile;
TRANSPOSE_12x4;
}
for (int c = 0; c < real_c; ++c) {
float *src_ptr = src_data + c * block_tile;
float *dst_ptr = dst_data + c * block_tile;
InputTransform6x6Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride);
}
#else
for (int l = 0; l < 6; ++l) {
float *src = src_data + l * pack_tile * block_tile;
// 12 * 4 -> 4 * 12
float tmp_mat[pack_tile][block_tile];
for (int i = 0; i < block_tile; ++i) {
for (int j = 0; j < pack_tile; ++j) {
tmp_mat[j][i] = src[i * pack_tile + j];
}
}
memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float));
}
float src[6];
float m[6];
for (int c = 0; c < real_c; ++c) {
for (int i = 0; i < block_tile; ++i) {
int tmp_index = c * block_tile + i;
for (int w = 0; w < 6; ++w) {
src[w] = src_data[tmp_index + w * src_point_stride];
}
float tmp1 = src[3] - src[1];
float tmp2 = src[4] - src[2];
m[0] = 4 * src[0] - 5 * src[2] + src[4];
m[1] = -4 * (src[1] + src[2]) + (src[3] + src[4]);
m[2] = 4 * (src[1] - src[2]) + (src[4] - src[3]);
m[3] = 2 * tmp1 + tmp2;
m[4] = -2 * tmp1 + tmp2;
m[5] = 4 * src[1] - 5 * src[3] + src[5];
for (int w = 0; w < 6; ++w) {
dst_data[tmp_index + w * dst_step] = m[w];
}
}
}
#endif
}
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
void InputTransform8x8Unit_block4(const float *src_data, float *dst_data, int src_step, int dst_step) {
MS_FLOAT32X4 src[64];
@ -334,6 +652,232 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
#endif
}
void InputTransform8x8Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) {
#ifdef ENABLE_ARM64
for (int l = 0; l < 8; ++l) {
const float *src_ptr = src_data + l * 8 * src_step;
float *dst_ptr = dst_data + l * dst_row_step;
MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step);
MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step);
MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step);
MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step);
MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 4 * src_step);
MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 5 * src_step);
MS_FLOAT32X4 s6 = MS_LDQ_F32(src_ptr + 6 * src_step);
MS_FLOAT32X4 s7 = MS_LDQ_F32(src_ptr + 7 * src_step);
MS_FLOAT32X4 m0 = MS_SUBQ_F32(
MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s0, 0.5625), MS_MULQ_N_F32(s2, 3.0625)), MS_MULQ_N_F32(s4, 3.5)), s6);
MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 1.125), MS_MULQ_N_F32(s5, 0.5));
MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 2.25), MS_MULQ_N_F32(s4, 3.25));
MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 1.625)), s6);
MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 1.625)), s6);
tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 0.5625), s5);
tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 0.5625), MS_MULQ_N_F32(s4, 2.5));
MS_FLOAT32X4 m3 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 2.5)), s6);
MS_FLOAT32X4 m4 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 2.5)), s6);
tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 0.375), MS_MULQ_N_F32(s5, 1.5));
tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 0.25), MS_MULQ_N_F32(s4, 1.25));
MS_FLOAT32X4 m5 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 1.875)), s6);
MS_FLOAT32X4 m6 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 1.875)), s6);
MS_FLOAT32X4 m7 = MS_ADDQ_F32(
MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s1, -0.5625), MS_MULQ_N_F32(s3, 3.0625)), MS_MULQ_N_F32(s5, 3.5)), s7);
MS_STQ_F32(dst_ptr + 0 * dst_step, m0);
MS_STQ_F32(dst_ptr + 1 * dst_step, m1);
MS_STQ_F32(dst_ptr + 2 * dst_step, m2);
MS_STQ_F32(dst_ptr + 3 * dst_step, m3);
MS_STQ_F32(dst_ptr + 4 * dst_step, m4);
MS_STQ_F32(dst_ptr + 5 * dst_step, m5);
MS_STQ_F32(dst_ptr + 6 * dst_step, m6);
MS_STQ_F32(dst_ptr + 7 * dst_step, m7);
}
#else
float src[8];
float m[8];
for (int i = 0; i < C4NUM; ++i) {
for (int l = 0; l < 8; ++l) {
for (int w = 0; w < 8; ++w) {
int tmp_index = l * 8 + w;
src[w] = src_data[i + tmp_index * src_step];
}
m[0] = 0.5625f * src[0] - 3.0625f * src[2] + 3.5f * src[4] - src[6];
float tmp1 = 1.125f * src[1] + 0.5f * src[5];
float tmp2 = 2.25f * src[2] - 3.25f * src[4];
m[1] = tmp1 + tmp2 - 1.625f * src[3] + src[6];
m[2] = tmp2 - tmp1 + 1.625f * src[3] + src[6];
tmp1 = 0.5625f * src[1] + src[5];
tmp2 = 0.5625f * src[2] - 2.5f * src[4];
m[3] = tmp1 + tmp2 - 2.5f * src[3] + src[6];
m[4] = tmp2 - tmp1 + 2.5f * src[3] + src[6];
tmp1 = 0.375f * src[1] + 1.5f * src[5];
tmp2 = 0.25f * src[2] - 1.25f * src[4];
m[5] = tmp1 + tmp2 - 1.875f * src[3] + src[6];
m[6] = tmp2 - tmp1 + 1.875f * src[3] + src[6];
m[7] = -0.5625f * src[1] + 3.0625f * src[3] - 3.5f * src[5] + src[7];
float *dst = dst_data + l * dst_row_step;
for (int w = 0; w < 8; ++w) {
dst[i + w * dst_step] = m[w];
}
}
}
#endif
}
#ifdef ENABLE_ARM64
void InputTransform8x8Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) {
LOAD_LINE_DATA(0);
LOAD_LINE_DATA(1);
LOAD_LINE_DATA(2);
LOAD_LINE_DATA(3);
LOAD_LINE_DATA(4);
LOAD_LINE_DATA(5);
LOAD_LINE_DATA(6);
LOAD_LINE_DATA(7);
MS_FLOAT32X4 m0 = MS_SUBQ_F32(
MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s00, 0.5625), MS_MULQ_N_F32(s20, 3.0625)), MS_MULQ_N_F32(s40, 3.5)), s60);
MS_FLOAT32X4 m1 = MS_SUBQ_F32(
MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s01, 0.5625), MS_MULQ_N_F32(s21, 3.0625)), MS_MULQ_N_F32(s41, 3.5)), s61);
MS_FLOAT32X4 m2 = MS_SUBQ_F32(
MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s02, 0.5625), MS_MULQ_N_F32(s22, 3.0625)), MS_MULQ_N_F32(s42, 3.5)), s62);
MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2);
MS_FLOAT32X4 tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 1.125), MS_MULQ_N_F32(s50, 0.5));
MS_FLOAT32X4 tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 1.125), MS_MULQ_N_F32(s51, 0.5));
MS_FLOAT32X4 tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 1.125), MS_MULQ_N_F32(s52, 0.5));
MS_FLOAT32X4 tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 2.25), MS_MULQ_N_F32(s40, 3.25));
MS_FLOAT32X4 tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 2.25), MS_MULQ_N_F32(s41, 3.25));
MS_FLOAT32X4 tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 2.25), MS_MULQ_N_F32(s42, 3.25));
m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 1.625)), s60);
m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 1.625)), s61);
m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 1.625)), s62);
MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2);
m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 1.625)), s60);
m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 1.625)), s61);
m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 1.625)), s62);
MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2);
tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 0.5625), s50);
tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 0.5625), s51);
tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 0.5625), s52);
tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 0.5625), MS_MULQ_N_F32(s40, 2.5));
tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 0.5625), MS_MULQ_N_F32(s41, 2.5));
tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 0.5625), MS_MULQ_N_F32(s42, 2.5));
m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 2.5)), s60);
m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 2.5)), s61);
m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 2.5)), s62);
MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2);
m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 2.5)), s60);
m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 2.5)), s61);
m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 2.5)), s62);
MS_STQ_F32(dst_ptr + 4 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 4 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 4 * dst_step + 2 * pack_tile, m2);
tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 0.375), MS_MULQ_N_F32(s50, 1.5));
tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 0.375), MS_MULQ_N_F32(s51, 1.5));
tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 0.375), MS_MULQ_N_F32(s52, 1.5));
tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 0.25), MS_MULQ_N_F32(s40, 1.25));
tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 0.25), MS_MULQ_N_F32(s41, 1.25));
tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 0.25), MS_MULQ_N_F32(s42, 1.25));
m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 1.875)), s60);
m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 1.875)), s61);
m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 1.875)), s62);
MS_STQ_F32(dst_ptr + 5 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 5 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 5 * dst_step + 2 * pack_tile, m2);
m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 1.875)), s60);
m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 1.875)), s61);
m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 1.875)), s62);
MS_STQ_F32(dst_ptr + 6 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 6 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 6 * dst_step + 2 * pack_tile, m2);
m0 = MS_ADDQ_F32(
MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s10, -0.5625), MS_MULQ_N_F32(s30, 3.0625)), MS_MULQ_N_F32(s50, 3.5)), s70);
m1 = MS_ADDQ_F32(
MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s11, -0.5625), MS_MULQ_N_F32(s31, 3.0625)), MS_MULQ_N_F32(s51, 3.5)), s71);
m2 = MS_ADDQ_F32(
MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s12, -0.5625), MS_MULQ_N_F32(s32, 3.0625)), MS_MULQ_N_F32(s52, 3.5)), s72);
MS_STQ_F32(dst_ptr + 7 * dst_step + 0 * pack_tile, m0);
MS_STQ_F32(dst_ptr + 7 * dst_step + 1 * pack_tile, m1);
MS_STQ_F32(dst_ptr + 7 * dst_step + 2 * pack_tile, m2);
}
#endif
void InputTransform8x8Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) {
int block_tile = 12;
int pack_tile = src_step;
int src_point_stride = block_tile * pack_tile;
#ifdef ENABLE_ARM64
for (int l = 0; l < 8; ++l) {
float *src_ptr = src_data + l * C4NUM * block_tile;
TRANSPOSE_12x4;
}
for (int c = 0; c < real_c; ++c) {
float *src_ptr = src_data + c * block_tile;
float *dst_ptr = dst_data + c * block_tile;
InputTransform8x8Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride);
}
#else
for (int l = 0; l < 8; ++l) {
float *src = src_data + l * pack_tile * block_tile;
// 12 * 4 -> 4 * 12
float tmp_mat[pack_tile][block_tile];
for (int i = 0; i < block_tile; ++i) {
for (int j = 0; j < pack_tile; ++j) {
tmp_mat[j][i] = src[i * pack_tile + j];
}
}
memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float));
}
float src[8];
float m[8];
for (int c = 0; c < real_c; ++c) {
for (int i = 0; i < block_tile; ++i) {
int tmp_index = c * block_tile + i;
for (int w = 0; w < 8; ++w) {
src[w] = src_data[tmp_index + w * src_point_stride];
}
m[0] = 0.5625f * src[0] - 3.0625f * src[2] + 3.5f * src[4] - src[6];
float tmp1 = 1.125f * src[1] + 0.5f * src[5];
float tmp2 = 2.25f * src[2] - 3.25f * src[4];
m[1] = tmp1 + tmp2 - 1.625f * src[3] + src[6];
m[2] = tmp2 - tmp1 + 1.625f * src[3] + src[6];
tmp1 = 0.5625f * src[1] + src[5];
tmp2 = 0.5625f * src[2] - 2.5f * src[4];
m[3] = tmp1 + tmp2 - 2.5f * src[3] + src[6];
m[4] = tmp2 - tmp1 + 2.5f * src[3] + src[6];
tmp1 = 0.375f * src[1] + 1.5f * src[5];
tmp2 = 0.25f * src[2] - 1.25f * src[4];
m[5] = tmp1 + tmp2 - 1.875f * src[3] + src[6];
m[6] = tmp2 - tmp1 + 1.875f * src[3] + src[6];
m[7] = -0.5625f * src[1] + 3.0625f * src[3] - 3.5f * src[5] + src[7];
for (int w = 0; w < 8; ++w) {
dst_data[tmp_index + w * dst_step] = m[w];
}
}
}
#endif
}
OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type) {
if (!CheckWinogradInputOutputUnit(input_unit, output_unit)) {
return NULL;

View File

@ -28,9 +28,21 @@ extern "C" {
#endif
typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
typedef void (*InputTransStepFunc)(const float *src_data, float *dst_data, int src_step, int dst_step,
int dst_row_step);
typedef void (*InputTransPackFunc)(float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
typedef struct TransFuncList {
InputTransFunc in_func_;
InputTransStepFunc in_step_func_;
InputTransPackFunc in_pack_func_;
OutputTransFunc out_func_;
} TransFuncList;
#define Load16Data \
src[0] = MS_LDQ_F32(src_data + 0 * src_step); \
src[1] = MS_LDQ_F32(src_data + 1 * src_step); \
@ -153,14 +165,66 @@ typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const fl
src[62] = MS_LDQ_F32(src_data + 62 * src_step); \
src[63] = MS_LDQ_F32(src_data + 63 * src_step);
#define LOAD_LINE_DATA(line) \
MS_FLOAT32X4 s##line##0 = MS_LDQ_F32(src_ptr + line * src_point_stride + 0 * pack_tile); \
MS_FLOAT32X4 s##line##1 = MS_LDQ_F32(src_ptr + line * src_point_stride + 1 * pack_tile); \
MS_FLOAT32X4 s##line##2 = MS_LDQ_F32(src_ptr + line * src_point_stride + 2 * pack_tile);
#define TRANSPOSE_12x4 \
MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * pack_tile); \
MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 1 * pack_tile); \
MS_FLOAT32X4 s6 = MS_LDQ_F32(src_ptr + 2 * pack_tile); \
MS_FLOAT32X4 s9 = MS_LDQ_F32(src_ptr + 3 * pack_tile); \
MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 4 * pack_tile); \
MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 5 * pack_tile); \
MS_FLOAT32X4 s7 = MS_LDQ_F32(src_ptr + 6 * pack_tile); \
MS_FLOAT32X4 s10 = MS_LDQ_F32(src_ptr + 7 * pack_tile); \
MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 8 * pack_tile); \
MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 9 * pack_tile); \
MS_FLOAT32X4 s8 = MS_LDQ_F32(src_ptr + 10 * pack_tile); \
MS_FLOAT32X4 s11 = MS_LDQ_F32(src_ptr + 11 * pack_tile); \
transpose4(&s0, &s3, &s6, &s9); \
transpose4(&s1, &s4, &s7, &s10); \
transpose4(&s2, &s5, &s8, &s11); \
MS_STQ_F32(src_ptr + 0 * pack_tile, s0); \
MS_STQ_F32(src_ptr + 1 * pack_tile, s1); \
MS_STQ_F32(src_ptr + 2 * pack_tile, s2); \
MS_STQ_F32(src_ptr + 3 * pack_tile, s3); \
MS_STQ_F32(src_ptr + 4 * pack_tile, s4); \
MS_STQ_F32(src_ptr + 5 * pack_tile, s5); \
MS_STQ_F32(src_ptr + 6 * pack_tile, s6); \
MS_STQ_F32(src_ptr + 7 * pack_tile, s7); \
MS_STQ_F32(src_ptr + 8 * pack_tile, s8); \
MS_STQ_F32(src_ptr + 9 * pack_tile, s9); \
MS_STQ_F32(src_ptr + 10 * pack_tile, s10); \
MS_STQ_F32(src_ptr + 11 * pack_tile, s11);
InputTransFunc GetInputTransFunc(int input_unit);
#ifdef ENABLE_ARM64
InputTransStepFunc GetInputTransStepFunc(int input_unit);
InputTransPackFunc GetInputTransPackFunc(int input_unit);
#endif
void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
void InputTransform4x4Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step);
void InputTransform4x4Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
void InputTransform6x6Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step);
void InputTransform6x6Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
void InputTransform8x8Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step);
void InputTransform8x8Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type);
#define Store4Data \

View File

@ -172,10 +172,10 @@ int ConvolutionWinogradFP32Coder::InitWeightBias() {
}
int ConvolutionWinogradFP32Coder::ConfigInputOutput() {
in_func_ = GetInputTransFunc(input_unit_);
MS_CHECK_TRUE(!in_func_.empty(), "Get input_trans_func failed.");
out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_);
MS_CHECK_TRUE(!out_func_.empty(), "Get output_trans_func_ failed.");
trans_func_str_.in_func_ = GetInputTransFunc(input_unit_);
MS_CHECK_TRUE(!trans_func_str_.in_func_.empty(), "Get input_trans_func failed.");
trans_func_str_.out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_);
MS_CHECK_TRUE(!trans_func_str_.out_func_.empty(), "Get output_trans_func_ failed.");
return RET_OK;
}
@ -269,9 +269,10 @@ int ConvolutionWinogradFP32Coder::DoCode(CoderContext *const context) {
<< allocator_->GetRuntimeAddr(gemm_out_) << ", " << allocator_->GetRuntimeAddr(tmp_data_) << ", "
<< allocator_->GetRuntimeAddr(col_buffer_) << "};\n";
code.CodeStruct("conv_parameter", *conv_param_);
code.CodeStruct("trans_func", trans_func_str_);
// code operator func
code.CodeFunction("ConvWinogardFp32", input_tensor_, trans_weight_, new_bias_, output_tensor_,
"tmp_buffer_address_list", kDefaultTaskId, "&conv_parameter", in_func_, out_func_);
"tmp_buffer_address_list", kDefaultTaskId, "&conv_parameter", "trans_func");
context->AppendCode(code.str());
return RET_OK;
}

View File

@ -22,6 +22,7 @@
#include <vector>
#include "coder/opcoders/base/conv2d_base_coder.h"
#include "nnacl/conv_parameter.h"
#include "wrapper/fp32/conv_winograd_fp32_wrapper.h"
namespace mindspore::lite::micro::nnacl {
class ConvolutionWinogradFP32Coder : public Conv2DBaseCoder {
@ -68,8 +69,7 @@ class ConvolutionWinogradFP32Coder : public Conv2DBaseCoder {
float *gemm_out_{nullptr};
float *col_buffer_{nullptr};
std::string in_func_;
std::string out_func_;
TransFuncStr trans_func_str_;
};
} // namespace mindspore::lite::micro::nnacl
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_FP32_CONVOLUTION_WINOGRAD_FP32_CODER_H_

View File

@ -157,4 +157,8 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const SpliceWrappe
splice_param.src_to_dst_row_offset);
}
void NNaclFp32Serializer::CodeStruct(const std::string &name, const TransFuncStr trans_func_str) {
CodeBaseStruct("TransFuncList", name, trans_func_str.in_func_, nullptr, nullptr, trans_func_str.out_func_);
}
} // namespace mindspore::lite::micro::nnacl

View File

@ -36,6 +36,7 @@
#include "nnacl/fp32/strided_slice_fp32.h"
#include "wrapper/fp32/arithmetic_fp32_wrapper.h"
#include "wrapper/base/affine_wrapper.h"
#include "wrapper/fp32/conv_winograd_fp32_wrapper.h"
namespace mindspore::lite::micro::nnacl {
@ -60,6 +61,7 @@ class NNaclFp32Serializer : public Serializer {
void CodeStruct(const std::string &name, const StridedSliceParameter &strided_slice_parameter);
void CodeStruct(const std::string &name, const ArithmeticWrapperInfo &arithmetic_wrapper_info);
void CodeStruct(const std::string &name, const SpliceWrapperParam &splice_param);
void CodeStruct(const std::string &name, const TransFuncStr trans_func_str);
};
} // namespace mindspore::lite::micro::nnacl

View File

@ -0,0 +1,30 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_CONV_WINOGRAD_FP32_WRAPPER_H_
#define MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_CONV_WINOGRAD_FP32_WRAPPER_H_
#include <string>
#ifdef __cplusplus
extern "C" {
#endif
typedef struct TransFuncStr {
std::string in_func_;
std::string out_func_;
} TransFuncStr;
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_CONV_WINOGRAD_FP32_WRAPPER_H_

View File

@ -119,21 +119,40 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
return RET_ERROR;
}
opt_input_trans_ = reinterpret_cast<float16_t *>(
ctx_->allocator->Malloc(thread_count_ * row_tile_ * input_unit_ * input_unit_ *
UP_ROUND(conv_param_->input_channel_, C8NUM) * sizeof(float16_t)));
if (opt_input_trans_ == nullptr) {
MS_LOG(ERROR) << "malloc opt_input_trans_ failed.";
return RET_ERROR;
}
tmp_buffer_address_list_[0] = trans_input_;
tmp_buffer_address_list_[1] = gemm_out_;
tmp_buffer_address_list_[2] = tmp_data_;
tmp_buffer_address_list_[3] = col_buffer_;
tmp_buffer_address_list_[4] = opt_input_trans_;
return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::ConfigInputOutput() {
in_func_ = GetInputTransFp16Func(input_unit_);
if (in_func_ == nullptr) {
trans_func_.in_func_ = GetInputTransFp16Func(input_unit_);
if (trans_func_.in_func_ == nullptr) {
MS_LOG(ERROR) << "in_func_ is null.";
return RET_ERROR;
}
out_func_ = GetOutputTransFp16Func(input_unit_, output_unit_, conv_param_->act_type_);
if (out_func_ == nullptr) {
#ifdef ENABLE_ARM64
trans_func_.in_step_func_ = GetInputTransStepFp16Func(input_unit_);
if (trans_func_.in_step_func_ == nullptr) {
MS_LOG(DEBUG) << "in_step_func_ is null.";
}
trans_func_.in_pack_func_ = GetInputTransPackFp16Func(input_unit_);
if (trans_func_.in_pack_func_ == nullptr) {
MS_LOG(DEBUG) << "in_pack_func_ is null.";
}
#endif
trans_func_.out_func_ = GetOutputTransFp16Func(input_unit_, output_unit_, conv_param_->act_type_);
if (trans_func_.out_func_ == nullptr) {
MS_LOG(ERROR) << "out_func_ is null.";
return RET_ERROR;
}
@ -219,7 +238,7 @@ int ConvolutionWinogradFP16CPUKernel::RunImpl(int task_id) {
}
ConvWinogardFp16(input_ptr, reinterpret_cast<float16_t *>(packed_weight_),
reinterpret_cast<const float16_t *>(bias_data_), output_ptr, tmp_buffer_address_list_, task_id,
conv_param_, in_func_, out_func_);
conv_param_, trans_func_);
return RET_OK;
}

View File

@ -65,6 +65,10 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseCPUKernel {
ctx_->allocator->Free(col_buffer_);
col_buffer_ = nullptr;
}
if (opt_input_trans_ != nullptr) {
ctx_->allocator->Free(opt_input_trans_);
opt_input_trans_ = nullptr;
}
}
int FilterWeight();
int kernel_unit_ = 0;
@ -74,11 +78,11 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseCPUKernel {
float16_t *trans_input_ = nullptr;
float16_t *gemm_out_ = nullptr;
float16_t *col_buffer_ = nullptr;
float16_t *opt_input_trans_ = nullptr;
float matrix_g_[64];
float matrix_gt_[64];
TmpBufferAddressFp16 tmp_buffer_address_list_[4] = {0};
InputTransFp16Func in_func_ = nullptr;
OutputTransFp16Func out_func_ = nullptr;
TmpBufferAddressFp16 tmp_buffer_address_list_[5] = {0};
TransFp16FuncList trans_func_;
int col_tile_ = 0;
int row_tile_ = 0;
};

View File

@ -69,21 +69,40 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
return RET_ERROR;
}
opt_input_trans_ = reinterpret_cast<float *>(
ctx_->allocator->Malloc(thread_count_ * tile_num_ * input_unit_ * input_unit_ *
UP_ROUND(conv_param_->input_channel_, tmp_data_tile_) * sizeof(float)));
if (opt_input_trans_ == nullptr) {
MS_LOG(ERROR) << "malloc opt_input_trans_ failed.";
return RET_ERROR;
}
tmp_buffer_address_list_[0] = trans_input_;
tmp_buffer_address_list_[1] = gemm_out_;
tmp_buffer_address_list_[2] = tmp_data_;
tmp_buffer_address_list_[3] = col_buffer_;
tmp_buffer_address_list_[4] = opt_input_trans_;
return RET_OK;
}
int ConvolutionWinogradCPUKernel::ConfigInputOutput() {
in_func_ = GetInputTransFunc(input_unit_);
if (in_func_ == nullptr) {
trans_func_.in_func_ = GetInputTransFunc(input_unit_);
if (trans_func_.in_func_ == nullptr) {
MS_LOG(ERROR) << "in_func_ is null.";
return RET_ERROR;
}
out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_);
if (out_func_ == nullptr) {
#ifdef ENABLE_ARM64
trans_func_.in_step_func_ = GetInputTransStepFunc(input_unit_);
if (trans_func_.in_step_func_ == nullptr) {
MS_LOG(DEBUG) << "in_step_func_ is null.";
}
trans_func_.in_pack_func_ = GetInputTransPackFunc(input_unit_);
if (trans_func_.in_pack_func_ == nullptr) {
MS_LOG(DEBUG) << "in_pack_func_ is null.";
}
#endif
trans_func_.out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_);
if (trans_func_.out_func_ == nullptr) {
MS_LOG(ERROR) << "out_func_ is null.";
return RET_ERROR;
}
@ -152,7 +171,7 @@ int ConvolutionWinogradCPUKernel::RunImpl(int task_id) {
CHECK_NULL_RETURN(output_data);
ConvWinogardFp32(ori_input_data, reinterpret_cast<float *>(packed_weight_),
reinterpret_cast<const float *>(bias_data_), output_data, tmp_buffer_address_list_, task_id,
conv_param_, in_func_, out_func_);
conv_param_, trans_func_);
return RET_OK;
}

View File

@ -62,6 +62,10 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
ctx_->allocator->Free(col_buffer_);
col_buffer_ = nullptr;
}
if (opt_input_trans_ != nullptr) {
ctx_->allocator->Free(opt_input_trans_);
opt_input_trans_ = nullptr;
}
}
int kernel_unit_{0};
int input_unit_{0};
@ -73,11 +77,11 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
float *trans_input_ = nullptr;
float *gemm_out_ = nullptr;
float *col_buffer_ = nullptr;
float *opt_input_trans_ = nullptr;
float matrix_g_[64];
float matrix_gt_[64];
TmpBufferAddress tmp_buffer_address_list_[4] = {nullptr};
InputTransFunc in_func_ = nullptr;
OutputTransFunc out_func_ = nullptr;
TmpBufferAddress tmp_buffer_address_list_[5] = {nullptr};
TransFuncList trans_func_;
};
} // namespace mindspore::kernel

View File

@ -3,13 +3,13 @@
# [second column]:accuracy limit for float16 in arm64 device
hdc_age_medium 5.9
beard 2
emotion 60
gender_res_large_deploy 0.1
emotion 216
gender_res_large_deploy 2
glasses 4
hat 2.5
ml_bank_detect_0312_tmp 20
ml_face_div_parsing 8
ml_hardware_eyeclose 0.1
ml_hardware_eyeclose 0.5
ml_ocr_detect_20200305 10
Mnet6_0312_extract_pay 15
pose_3d 90
@ -37,7 +37,7 @@ ml_ocr_sfz_add_final_0325 0.1
ml_hardware_pose 2
ml_bank_recog 0.1
2012_ATLANTA_10class_20190131_v4.0 12
mnet 12
mnet 13
recognition 10.8
ml_face_landmark 1
model_hebing_3branch 40
@ -71,31 +71,31 @@ ml_location_scene_division 8
ml_tabel_recog 0.1
ml_text_division 12
# Further analysis in the future to model ml_video_edit_Mnet
ml_video_edit_Mnet 11.5
ml_video_edit_Mnet 15.5
ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145 0.5
hdc_contour_pose_128 0.5
hdc_emotion 0.5
hdc_fivembnet 1
hdc_isface 0.5
hdc_mobilenetface 11.5 # small output causes big bias
hdc_retinaface 14
hdc_retinaface 15
hdc_resnet 7
ml_video_edit_detect_20211111 2.5
ml_video_edit_detect_20211111 3
ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145_20210121 0.5
ml_video_edit_have_imageProcessLayer_interpTo145_20201015 0.5
ml_video_edit_MnetN367_extract_1010_pay 1
ml_video_edit_person_divison_pic 0.5
ml_video_edit_reid 1
ml_video_edit_v10_best_model_nomean_20200723 5.1
ml_video_edit_v10_best_model_nomean_20200723 6
ml_video_edit_img_segment 3
ml_video_edit_video_segment_gauss_adaptis_part1 5
# When the input range is [-1,1], the precision is poor, and the output value is very small (10e-5). If the input range is adjusted to [0,255], the precision will decrease to 15.5415%, and the rest is cumulative error.
ml_handpose 175
ml_handpose 177
hdc_Face_Aesthetic_MTI_Aesthetic 0.5
ml_face_compare 8.7
ml_face_tracking 2.5
ml_face_beard 0.6
ml_face_age 3.7
ml_face_beard 1
ml_face_age 4
ml_face_pose 1
ml_face_isface 0.5
ml_face_glasses 3.4
@ -108,13 +108,13 @@ ml_Hand_deploy 4
ml_hand_3d_detection 12
ml_hand_3d_regression 5.4
# ml_ARengine23_bodypose: The difference of output node divided by a very small value leads to a large error
ml_ARengine23_bodypose 56
ml_ARengine23_bodypose 57
ml_ocr_bank_card_detection_inception_tmp 20
ml_ocr_bank_card_recognition_fcny 0.5
hiai_cv_aestheticsEngineModel_osp 1.6
hiai_cv_aestheticsEngineModel_osp 3.5
ml_face_hat 2.2
bank_card_recognition_fcny 17
bank_card_detection_inception_tmp 12
bank_card_recognition_fcny 19
bank_card_detection_inception_tmp 13.5
ml_ocr_identify_card_fcny 0.5
ml_ocr_identify_card_detect_tmp 2
identify_card_detect_tmp 0.5
@ -123,18 +123,18 @@ ml_2012_ocr_rec_caffe 0.5
ml_lable_model_hebing_device 3
ml_face_sex 0.7
# ml_face_mnet: The precision problem caused by cumulative error.
ml_face_mnet 12
ml_face_mnet 13
ml_segmentation_atlanta_1 0.5
bolt_deploy_color-server 0.5
ml_face_emotion 0.5
hdc_ocr_recog_horizontal 0.5
# The outputs of two Heatmap_depth models have small value
ml_Heatmap_depth_240180;2 10
ml_Heatmap_depth_240180;2 14.5
ml_Heatmap_depth_180240;2 7
ml_video_edit_hair_dyeing_segmodel_v3 0.5
ml_video_edit_hairline_segmentation;3 1.5
ml_video_edit_seg_320 0.5
hiai_machine_vision_jfr_newmodel_2730_houduan_yolo 5
hiai_machine_vision_mobileNet101_nosoftce_mobilenet_resnet 7.5
ml_video_edit_person_divison_video;2 38
ml_video_edit_person_divison_video;2 42
ml_video_edit_hair_dyeing_segmodel_20211119 0.5

View File

@ -3,7 +3,7 @@
# [second column]:accuracy limit for float16 in arm64 device
mtk_detect-mbv2-shortcut-400-400-simplified.onnx 4
mtk_face_features_v3.onnx 20
emotion-ferplus-8.onnx 1
emotion-ferplus-8.onnx 1.5
#rcnn-ilsvrc13-9.onnx 0.1
efficientnet-lite4-11.onnx 2
mobilenetv2-7.onnx 8
@ -27,7 +27,7 @@ mnist-8.onnx 10
crnn_lite_lstm_v2.onnx;1;32,32,32,1 0.3
#psenet_lite_mbv2.onnx;1;1,32,32,3 0.6
#occasionally aborted
super-resolution-10.onnx;1;1,224,224,1 4.5
super-resolution-10.onnx;1;1,224,224,1 5
tinyyolov2-8.onnx;1;1,416,416,3 5.5
#ml_2012_ocr_cn.onnx -1
#ml_2012_ocr_cn_noLSTM.onnx 1
@ -52,10 +52,10 @@ ml_video_edit_style_transfer_autoportrait.onnx 2
ml_video_edit_style_transfer_candy.onnx 2
ml_video_edit_style_transfer_gongnongbing.onnx 2
ml_video_edit_style_transfer_starry.onnx 2
hdc_Face_Landmark5_MTI_Aesthetic.onnx 0.5
hdc_Face_Landmark5_MTI_Aesthetic.onnx 1
hdc_Image_Aesthetic_MTI_Aesthetic.onnx 0.5
hdc_resnet_1w_class.onnx 6
gts_text_detection.onnx;1;1,224,224,3 10
gts_text_detection.onnx;1;1,224,224,3 11
hdc_Face_Emotion_MTI_Aesthetic.onnx 144
ml_video_edit_imitate_filter.onnx 120
ml_facedetector.onnx 6
@ -71,7 +71,7 @@ mtk_emotions-d2012-75.onnx 6
mtk_detect-mbv1-shortcut-400-400.onnx 0.5
mtk_detect-mbv2-shortcut-400-400.onnx 0.5
mtk_detect_mbv1_640_480.onnx 0.5
mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified_onnx.onnx 2
mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified_onnx.onnx 2.5
mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified_onnx.onnx 6.5
mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified_onnx.onnx 2.5
mtk_detect_mbv1_640_480_nopostprocess_simplified_onnx.onnx;1;1,480,640,3 2
@ -87,7 +87,7 @@ Q888_iris_detect.onnx 0.5
ssd_mobilenet_v1_10.onnx;1;1,383,640,3 0.5
# The output from a conv in the later part contains many minus values, the following leakyRelu makes them become very
# close to 0 (-e^-4). The fp16 precision lost a lot in this case and it affects the following computation.
Harmony_Voiceprint.onnx;1;1,200,40,1 21.5 # small output causes big bias
Harmony_Voiceprint.onnx;1;1,200,40,1 68 # small output causes big bias
# A matmul op in the later part produces overflowed output values (>65504).
#ml_video_edit_art_generate_20210513.onnx nan
# bn_fusion causes a big bias(maybe random), need to debug later: The original bias is 2.1
@ -104,11 +104,11 @@ ml_video_edit_makeup_mobilenetv203.onnx 4
# The input of ml_video_edit_hair_dyeing_migrate_v2.onnx should be between [0, 1]
ml_video_edit_hair_dyeing_migrate_v2.onnx;4 2.5
Q888_CV_face_recognition_self.onnx 3.9
ml_video_edit_hair_dyeing_migrate_v2_fix.onnx;4 3
ml_video_edit_hair_dyeing_migrate_v2_fix.onnx;4 3.5
ml_intelligent_cockpit_model.onnx;3;1,32:1,32:1,32 3.8
CloudBU_FSRCNN_RTC_8ch_3450_QP9.onnx;1;1,225,225,3 1.5
CloudBU_rfdn_rtc_x2_ver2_13.onnx;1;1,225,225,3 1.5
CloudBU_rfdn_rtc_x2_ver2_3450.onnx;1;1,225,225,3 3.0
CloudBU_rfdn_rtc_x2_ver2_3450.onnx;1;1,225,225,3 4
ml_motion_capture_nanodet_m_0.5x_people_0928_sim.onnx 8
ml_motion_capture_smpl_0916.onnx;3
ml_motion_capture_spin_mobile_mv3_v3_57mm_sim.onnx;5 18
@ -116,7 +116,7 @@ ml_video_edit_dimming_tech_model_345000_color.onnx;2 2
Ireland_ulfgf.onnx;1;1,240,320,3
Ireland_gaze_corrector.onnx;3 15
Ireland_face_detector.onnx 2
Ireland_gaze_estimator_ng.onnx 6
Ireland_gaze_estimator_ng.onnx 8
carbu_intelligent_cockpit_fasttext_best.onnx 0.5
ml_video_edit_shot_selection_yolox_nano_coco_reduced.onnx 3
ml_video_edit_shot_selection_face_emotion.onnx 0.7

View File

@ -18,7 +18,7 @@ hiai_ssd_mobilenetv2_object.pb 15
hiai_humanDetection.pb 3.5
hiai_PoseEstimation_Pcm.pb 0.5
# The last layer has a very small value, which leads to a large error
hiai_cn_recognize_modify_padv2.pb;1;1,32,512,1 27
hiai_cn_recognize_modify_padv2.pb;1;1,32,512,1 37
hiai_model_normalize_object_scene_ps_20200519.pb;1;1,224,224,3 17.1
# The output of mtk_model_ckpt.pb has small value
mtk_model_ckpt.pb 19.5
@ -33,14 +33,14 @@ mtk_face_features_v1.pb 26
model_normalize_object_scene_ps_20200519.pb;1;1,224,224,3 10
hiai_AADB_HADB_MBV2_model.pb;1;1,224,224,3 6
hiai_frozen_inference_graph.pb 12
hiai_lm_inference_graph.pb 1.2
hiai_lm_inference_graph.pb 1.5
hiai_ghostnet.pb 0.9
hiai_face_model_npu.pb 0.5
hiai_cv_focusShootOCRModel_02.pb 10.5
hiai_cv_focusShootOCRModel_02.pb 12.5
hiai_label_and_video.pb;1;1,224,224,3 23
hiai_dress_detect.pb;1;1,960,960,3 1.5
hiai_iMaxDN_RGB.pb 0.5
hiai_iMaxSR_RGB.pb 3.5
hiai_iMaxSR_RGB.pb 5
hiai_ctpn_feature_map.pb 6.5
hiai_cpu_face_gazing.pb 0.5
hiai_cpu_face_emotion.pb 2.2
@ -49,7 +49,7 @@ Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid.pb 1.5
# The input of Q_crnn_ori_75w_slim model is between 0-255, but its outputs has small values (e-6).
Q_crnn_ori_75w_slim_norm.pb 37
# The output of Q_crnn_ori_v2 model has small values (e-4).
Q_crnn_ori_v2_405001_notrans_nopre.pb 24
Q_crnn_ori_v2_405001_notrans_nopre.pb 33
# The input of hiai_latin models are between 0-255
hiai_latin_ocr.pb 4
hiai_latin_ocr_1.pb 3.5
@ -68,7 +68,7 @@ ml_vision_guide_detection2.pb;1;1,320,320,1 1
ml_tts_encoder.pb;4;1,44:1:1:1 9
# encoder_0111_control_flow.pb is same as ml_tts_encoder_control_flow.pb
#encoder_0111_control_flow.pb;4;1:1,44:1:1 10
ml_video_edit_video_segment_gauss_adaptis_part2.pb;2 12.1
ml_video_edit_video_segment_gauss_adaptis_part2.pb;2 16
ml_video_edit_img_segment_adaptise.pb;2 40
ml_video_edit_oneclick_adaptis.pb;3 6
#decoder_step_201217.pb is the same model as ml_tts_decoder.pb.

View File

@ -2,7 +2,7 @@
# content after ";" can be omitted.
# [second column]:accuracy limit for float16 in arm64 device
hiai_model_0909_kd_rot_ps_softmax.tflite 10
hiai_chinese_english_recognize_model_float32.tflite 13
hiai_chinese_english_recognize_model_float32.tflite 13.5
hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite 10
hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite 10
hiai_cn_recognize_modify_padv2.tflite 14
@ -64,7 +64,7 @@ inception_resnet_v2.tflite 10
ml_ocr_latin.tflite 15
hiai_PoseEstimation_Pcm.tflite 15
hiai_ssd_mobilenetv2_object.tflite 60
hiai_cv_focusShootOCRModel_02.tflite 13
hiai_cv_focusShootOCRModel_02.tflite 13.5
hiai_cv_poseEstimation.tflite 190
inception_v4.tflite 10
mtk_model_normalize_object_scene_ps_20200519_f16.tflite 10
@ -129,8 +129,8 @@ mtk_pose.tflite 2
mtk_model_emotions_0727_nosoftmax.tflite 2
mtk_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite 22
mtk_276landmark_0913.tflite 16
mtk_face_recognition.tflite 8
mtk_convert_model.tflite 5.3
mtk_face_recognition.tflite 11
mtk_convert_model.tflite 7.5
smartreply.tflite 0.1
mindspore_text_classification_tflite.tflite 9.2 # small output causes big bias
#ml_location.tflite 0.1
@ -176,7 +176,7 @@ Q_convert.tflite 12
# the input of Q_crnn_ori_75w_slim model is between 0-255, but its outputs has small values (e-6).
Q_crnn_ori_75w_slim_norm_pb2tflite.tflite 29
# the output of Q_crnn_ori_v2 model has small values (e-4).
Q_crnn_ori_v2_405001_notrans_nopre_pb2tflite.tflite 36
Q_crnn_ori_v2_405001_notrans_nopre_pb2tflite.tflite 42
# the inputs of two Q_crnn_screen_slim400w models are between 0-255, but their outputs have small values (e-7).
Q_crnn_screen_slim400w_more_20w_pb2tflite.tflite 71
Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid_tflite.tflite 1.5
@ -202,14 +202,14 @@ Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite 2
# input data: -1~1
Q888_face_emo_dress_mv3_orderd.tflite 2.5
Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite 1
Q_iMaxSR_RGB_385_p_pb2tflite.tflite 5
Q_iMaxSR_RGB_385_p_pb2tflite.tflite 5.5
bloom_new_detect.tflite 3.5
bloom_model_age_gender.tflite 0.5
bloom_isface.tflite 0.5
# The output values of conv layers range from -e±5 to e±5, which almost reaches the representation limit of fp16. In
# this range, the fp16 data will has big bias. And the accumulation of this bias lowers the final precision.
hiai_object_detect_814.tflite 14
ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2 12.1
ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2 16
ml_video_edit_img_segment_adaptise_pb2tflite.tflite;2 0.5
hdc_tb_cn_neg.tflite;3 295
# The input of hiai_cv_labelDetectorModel_v3.tflite is between 0-255.