!11977 [MS][LITE][CPU]change the parallel strategy for fp16 winograd

From: @fuzhiye
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2021-02-05 10:30:23 +08:00 committed by Gitee
commit ca3f916c1e
12 changed files with 377 additions and 332 deletions

View File

@ -160,7 +160,9 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_);
int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_);
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, tile_num);
int per_thread_num = UP_DIV(output_count, conv_param->thread_num_);
int real_tile = per_thread_num < tile_num ? per_thread_num : tile_num;
int output_tile_count = UP_DIV(output_count, real_tile);
int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_;
@ -178,9 +180,12 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_;
int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
int out_tile_index = thread_id * tile_num;
int cal_num = output_count - thread_id * tile_num;
cal_num = cal_num > tile_num ? tile_num : cal_num;
int out_tile_index = thread_id * real_tile;
int cal_num = output_count - thread_id * real_tile;
cal_num = cal_num > real_tile ? real_tile : cal_num;
if (cal_num <= 0) {
return;
}
WinogradInputTransformFp16(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
in_func);
@ -189,7 +194,7 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset;
float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, tile_num, in_channel);
RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8);
}

View File

@ -16,201 +16,212 @@
#include "nnacl/fp16/matmul_fp16.h"
void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) {
static void Col2Row8SrcFromFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
int row_c8 = row / C8NUM * C8NUM;
int col_c8 = col / C8NUM * C8NUM;
const float16_t *src = (const float16_t *)src_ptr;
int ci = 0;
for (; ci < col_c8; ci += C8NUM) {
int ri = 0;
for (; ri < row_c8; ri += C8NUM) {
const float16_t *src_ptr1 = src + ci * row + ri;
float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
#ifdef ENABLE_ARM64
size_t strid_row = row * 2;
asm volatile(
"mov x10, %[src_ptr1]\n"
"mov x11, %[dst_ptr1]\n"
"mov x12, %[strid_row]\n"
"ld1 {v0.8h}, [x10], x12\n"
"ld1 {v1.8h}, [x10], x12\n"
"ld1 {v2.8h}, [x10], x12\n"
"ld1 {v3.8h}, [x10], x12\n"
"ld1 {v4.8h}, [x10], x12\n"
"ld1 {v5.8h}, [x10], x12\n"
"ld1 {v6.8h}, [x10], x12\n"
"ld1 {v7.8h}, [x10], x12\n"
"zip1 v8.8h, v0.8h, v1.8h\n"
"zip1 v9.8h, v2.8h, v3.8h\n"
"zip1 v10.8h, v4.8h, v5.8h\n"
"zip1 v11.8h, v6.8h, v7.8h\n"
"trn1 v12.4s, v8.4s, v9.4s\n"
"trn1 v14.4s, v10.4s, v11.4s\n"
"trn2 v13.4s, v8.4s, v9.4s\n"
"trn2 v15.4s, v10.4s, v11.4s\n"
"trn1 v16.2d, v12.2d, v14.2d\n"
"trn2 v18.2d, v12.2d, v14.2d\n"
"trn1 v17.2d, v13.2d, v15.2d\n"
"trn2 v19.2d, v13.2d, v15.2d\n"
"zip2 v8.8h, v0.8h, v1.8h\n"
"zip2 v9.8h, v2.8h, v3.8h\n"
"zip2 v10.8h, v4.8h, v5.8h\n"
"zip2 v11.8h, v6.8h, v7.8h\n"
"trn1 v12.4s, v8.4s, v9.4s\n"
"trn1 v14.4s, v10.4s, v11.4s\n"
"trn2 v13.4s, v8.4s, v9.4s\n"
"trn2 v15.4s, v10.4s, v11.4s\n"
"trn1 v20.2d, v12.2d, v14.2d\n"
"trn2 v22.2d, v12.2d, v14.2d\n"
"trn1 v21.2d, v13.2d, v15.2d\n"
"trn2 v23.2d, v13.2d, v15.2d\n"
"st1 {v16.8h}, [x11], #16\n"
"st1 {v17.8h}, [x11], #16\n"
"st1 {v18.8h}, [x11], #16\n"
"st1 {v19.8h}, [x11], #16\n"
"st1 {v20.8h}, [x11], #16\n"
"st1 {v21.8h}, [x11], #16\n"
"st1 {v22.8h}, [x11], #16\n"
"st1 {v23.8h}, [x11], #16\n"
:
: [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row)
: "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
for (int tr = 0; tr < C8NUM; ++tr) {
for (int tc = 0; tc < C8NUM; ++tc) {
dst_ptr1[tr * C8NUM + tc] = src_ptr1[tc * row + tr];
}
}
#endif
}
for (; ri < row; ++ri) {
const float16_t *src_ptr1 = src + ci * row;
float16_t *dst_ptr1 = dst_ptr + ci * row;
for (int tc = 0; tc < C8NUM; ++tc) {
dst_ptr1[ri * C8NUM + tc] = src_ptr1[tc * row + ri];
}
}
}
for (int r = 0; r < row; r++) {
for (int tc = ci; tc < col; tc++) {
int cd8 = tc / C8NUM;
int cm8 = tc % C8NUM;
dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[tc * row + r];
}
}
}
static void Col2Row8SrcFromFp32(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
int row_c8 = row / C8NUM * C8NUM;
int col_c8 = col / C8NUM * C8NUM;
int ci = 0;
const float *src = (const float *)src_ptr;
for (; ci < col_c8; ci += C8NUM) {
int ri = 0;
for (; ri < row_c8; ri += C8NUM) {
const float *src_ptr1 = src + ci * row + ri;
float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
#ifdef ENABLE_ARM64
size_t strid_row = row * 4;
asm volatile(
"mov x10, %[src_ptr1]\n"
"mov x11, %[dst_ptr1]\n"
"mov x12, %[strid_row]\n"
"ld1 {v8.4s, v9.4s}, [x10], x12\n"
"ld1 {v10.4s, v11.4s}, [x10], x12\n"
"ld1 {v12.4s, v13.4s}, [x10], x12\n"
"ld1 {v14.4s, v15.4s}, [x10], x12\n"
"ld1 {v16.4s, v17.4s}, [x10], x12\n"
"ld1 {v18.4s, v19.4s}, [x10], x12\n"
"ld1 {v20.4s, v21.4s}, [x10], x12\n"
"ld1 {v22.4s, v23.4s}, [x10], x12\n"
"fcvtn v0.4h, v8.4s\n"
"fcvtn2 v0.8h, v9.4s\n"
"fcvtn v1.4h, v10.4s\n"
"fcvtn2 v1.8h, v11.4s\n"
"fcvtn v2.4h, v12.4s\n"
"fcvtn2 v2.8h, v13.4s\n"
"fcvtn v3.4h, v14.4s\n"
"fcvtn2 v3.8h, v15.4s\n"
"fcvtn v4.4h, v16.4s\n"
"fcvtn2 v4.8h, v17.4s\n"
"fcvtn v5.4h, v18.4s\n"
"fcvtn2 v5.8h, v19.4s\n"
"fcvtn v6.4h, v20.4s\n"
"fcvtn2 v6.8h, v21.4s\n"
"fcvtn v7.4h, v22.4s\n"
"fcvtn2 v7.8h, v23.4s\n"
"zip1 v8.8h, v0.8h, v1.8h\n"
"zip1 v9.8h, v2.8h, v3.8h\n"
"zip1 v10.8h, v4.8h, v5.8h\n"
"zip1 v11.8h, v6.8h, v7.8h\n"
"trn1 v12.4s, v8.4s, v9.4s\n"
"trn1 v14.4s, v10.4s, v11.4s\n"
"trn2 v13.4s, v8.4s, v9.4s\n"
"trn2 v15.4s, v10.4s, v11.4s\n"
"trn1 v16.2d, v12.2d, v14.2d\n"
"trn2 v18.2d, v12.2d, v14.2d\n"
"trn1 v17.2d, v13.2d, v15.2d\n"
"trn2 v19.2d, v13.2d, v15.2d\n"
"zip2 v8.8h, v0.8h, v1.8h\n"
"zip2 v9.8h, v2.8h, v3.8h\n"
"zip2 v10.8h, v4.8h, v5.8h\n"
"zip2 v11.8h, v6.8h, v7.8h\n"
"trn1 v12.4s, v8.4s, v9.4s\n"
"trn1 v14.4s, v10.4s, v11.4s\n"
"trn2 v13.4s, v8.4s, v9.4s\n"
"trn2 v15.4s, v10.4s, v11.4s\n"
"trn1 v20.2d, v12.2d, v14.2d\n"
"trn2 v22.2d, v12.2d, v14.2d\n"
"trn1 v21.2d, v13.2d, v15.2d\n"
"trn2 v23.2d, v13.2d, v15.2d\n"
"st1 {v16.8h}, [x11], #16\n"
"st1 {v17.8h}, [x11], #16\n"
"st1 {v18.8h}, [x11], #16\n"
"st1 {v19.8h}, [x11], #16\n"
"st1 {v20.8h}, [x11], #16\n"
"st1 {v21.8h}, [x11], #16\n"
"st1 {v22.8h}, [x11], #16\n"
"st1 {v23.8h}, [x11], #16\n"
:
: [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row)
: "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
for (int tr = 0; tr < C8NUM; ++tr) {
for (int tc = 0; tc < C8NUM; ++tc) {
dst_ptr1[tr * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + tr]);
}
}
#endif
}
for (; ri < row; ++ri) {
const float *src_ptr1 = src + ci * row;
float16_t *dst_ptr1 = dst_ptr + ci * row;
for (int tc = 0; tc < C8NUM; ++tc) {
dst_ptr1[ri * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + ri]);
}
}
}
for (int r = 0; r < row; r++) {
for (int tc = ci; tc < col; tc++) {
int cd8 = tc / C8NUM;
int cm8 = tc % C8NUM;
dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = (float16_t)(src[tc * row + r]);
}
}
}
void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) {
if (src_float16) {
const float16_t *src = (const float16_t *)src_ptr;
for (; ci < col_c8; ci += C8NUM) {
int ri = 0;
for (; ri < row_c8; ri += C8NUM) {
const float16_t *src_ptr1 = src + ci * row + ri;
float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
#ifdef ENABLE_ARM64
size_t strid_row = row * 2;
asm volatile(
"mov x10, %[src_ptr1]\n"
"mov x11, %[dst_ptr1]\n"
"mov x12, %[strid_row]\n"
"ld1 {v0.8h}, [x10], x12\n"
"ld1 {v1.8h}, [x10], x12\n"
"ld1 {v2.8h}, [x10], x12\n"
"ld1 {v3.8h}, [x10], x12\n"
"ld1 {v4.8h}, [x10], x12\n"
"ld1 {v5.8h}, [x10], x12\n"
"ld1 {v6.8h}, [x10], x12\n"
"ld1 {v7.8h}, [x10], x12\n"
"zip1 v8.8h, v0.8h, v1.8h\n"
"zip1 v9.8h, v2.8h, v3.8h\n"
"zip1 v10.8h, v4.8h, v5.8h\n"
"zip1 v11.8h, v6.8h, v7.8h\n"
"trn1 v12.4s, v8.4s, v9.4s\n"
"trn1 v14.4s, v10.4s, v11.4s\n"
"trn2 v13.4s, v8.4s, v9.4s\n"
"trn2 v15.4s, v10.4s, v11.4s\n"
"trn1 v16.2d, v12.2d, v14.2d\n"
"trn2 v18.2d, v12.2d, v14.2d\n"
"trn1 v17.2d, v13.2d, v15.2d\n"
"trn2 v19.2d, v13.2d, v15.2d\n"
"zip2 v8.8h, v0.8h, v1.8h\n"
"zip2 v9.8h, v2.8h, v3.8h\n"
"zip2 v10.8h, v4.8h, v5.8h\n"
"zip2 v11.8h, v6.8h, v7.8h\n"
"trn1 v12.4s, v8.4s, v9.4s\n"
"trn1 v14.4s, v10.4s, v11.4s\n"
"trn2 v13.4s, v8.4s, v9.4s\n"
"trn2 v15.4s, v10.4s, v11.4s\n"
"trn1 v20.2d, v12.2d, v14.2d\n"
"trn2 v22.2d, v12.2d, v14.2d\n"
"trn1 v21.2d, v13.2d, v15.2d\n"
"trn2 v23.2d, v13.2d, v15.2d\n"
"st1 {v16.8h}, [x11], #16\n"
"st1 {v17.8h}, [x11], #16\n"
"st1 {v18.8h}, [x11], #16\n"
"st1 {v19.8h}, [x11], #16\n"
"st1 {v20.8h}, [x11], #16\n"
"st1 {v21.8h}, [x11], #16\n"
"st1 {v22.8h}, [x11], #16\n"
"st1 {v23.8h}, [x11], #16\n"
:
: [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row)
: "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
for (int tr = 0; tr < C8NUM; ++tr) {
for (int tc = 0; tc < C8NUM; ++tc) {
dst_ptr1[tr * C8NUM + tc] = src_ptr1[tc * row + tr];
}
}
#endif
}
for (; ri < row; ++ri) {
const float16_t *src_ptr1 = src + ci * row;
float16_t *dst_ptr1 = dst_ptr + ci * row;
for (int tc = 0; tc < C8NUM; ++tc) {
dst_ptr1[ri * C8NUM + tc] = src_ptr1[tc * row + ri];
}
}
}
for (int r = 0; r < row; r++) {
for (int tc = ci; tc < col; tc++) {
int cd8 = tc / C8NUM;
int cm8 = tc % C8NUM;
dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[tc * row + r];
}
}
Col2Row8SrcFromFp16(src_ptr, dst_ptr, row, col);
} else {
const float *src = (const float *)src_ptr;
for (; ci < col_c8; ci += C8NUM) {
int ri = 0;
for (; ri < row_c8; ri += C8NUM) {
const float *src_ptr1 = src + ci * row + ri;
float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
#ifdef ENABLE_ARM64
size_t strid_row = row * 4;
asm volatile(
"mov x10, %[src_ptr1]\n"
"mov x11, %[dst_ptr1]\n"
"mov x12, %[strid_row]\n"
"ld1 {v8.4s, v9.4s}, [x10], x12\n"
"ld1 {v10.4s, v11.4s}, [x10], x12\n"
"ld1 {v12.4s, v13.4s}, [x10], x12\n"
"ld1 {v14.4s, v15.4s}, [x10], x12\n"
"ld1 {v16.4s, v17.4s}, [x10], x12\n"
"ld1 {v18.4s, v19.4s}, [x10], x12\n"
"ld1 {v20.4s, v21.4s}, [x10], x12\n"
"ld1 {v22.4s, v23.4s}, [x10], x12\n"
"fcvtn v0.4h, v8.4s\n"
"fcvtn2 v0.8h, v9.4s\n"
"fcvtn v1.4h, v10.4s\n"
"fcvtn2 v1.8h, v11.4s\n"
"fcvtn v2.4h, v12.4s\n"
"fcvtn2 v2.8h, v13.4s\n"
"fcvtn v3.4h, v14.4s\n"
"fcvtn2 v3.8h, v15.4s\n"
"fcvtn v4.4h, v16.4s\n"
"fcvtn2 v4.8h, v17.4s\n"
"fcvtn v5.4h, v18.4s\n"
"fcvtn2 v5.8h, v19.4s\n"
"fcvtn v6.4h, v20.4s\n"
"fcvtn2 v6.8h, v21.4s\n"
"fcvtn v7.4h, v22.4s\n"
"fcvtn2 v7.8h, v23.4s\n"
"zip1 v8.8h, v0.8h, v1.8h\n"
"zip1 v9.8h, v2.8h, v3.8h\n"
"zip1 v10.8h, v4.8h, v5.8h\n"
"zip1 v11.8h, v6.8h, v7.8h\n"
"trn1 v12.4s, v8.4s, v9.4s\n"
"trn1 v14.4s, v10.4s, v11.4s\n"
"trn2 v13.4s, v8.4s, v9.4s\n"
"trn2 v15.4s, v10.4s, v11.4s\n"
"trn1 v16.2d, v12.2d, v14.2d\n"
"trn2 v18.2d, v12.2d, v14.2d\n"
"trn1 v17.2d, v13.2d, v15.2d\n"
"trn2 v19.2d, v13.2d, v15.2d\n"
"zip2 v8.8h, v0.8h, v1.8h\n"
"zip2 v9.8h, v2.8h, v3.8h\n"
"zip2 v10.8h, v4.8h, v5.8h\n"
"zip2 v11.8h, v6.8h, v7.8h\n"
"trn1 v12.4s, v8.4s, v9.4s\n"
"trn1 v14.4s, v10.4s, v11.4s\n"
"trn2 v13.4s, v8.4s, v9.4s\n"
"trn2 v15.4s, v10.4s, v11.4s\n"
"trn1 v20.2d, v12.2d, v14.2d\n"
"trn2 v22.2d, v12.2d, v14.2d\n"
"trn1 v21.2d, v13.2d, v15.2d\n"
"trn2 v23.2d, v13.2d, v15.2d\n"
"st1 {v16.8h}, [x11], #16\n"
"st1 {v17.8h}, [x11], #16\n"
"st1 {v18.8h}, [x11], #16\n"
"st1 {v19.8h}, [x11], #16\n"
"st1 {v20.8h}, [x11], #16\n"
"st1 {v21.8h}, [x11], #16\n"
"st1 {v22.8h}, [x11], #16\n"
"st1 {v23.8h}, [x11], #16\n"
:
: [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row)
: "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
for (int tr = 0; tr < C8NUM; ++tr) {
for (int tc = 0; tc < C8NUM; ++tc) {
dst_ptr1[tr * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + tr]);
}
}
#endif
}
for (; ri < row; ++ri) {
const float *src_ptr1 = src + ci * row;
float16_t *dst_ptr1 = dst_ptr + ci * row;
for (int tc = 0; tc < C8NUM; ++tc) {
dst_ptr1[ri * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + ri]);
}
}
}
for (int r = 0; r < row; r++) {
for (int tc = ci; tc < col; tc++) {
int cd8 = tc / C8NUM;
int cm8 = tc % C8NUM;
dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = (float16_t)(src[tc * row + r]);
}
}
Col2Row8SrcFromFp32(src_ptr, dst_ptr, row, col);
}
return;
}
@ -274,126 +285,129 @@ void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const f
MatVecMulFp16Neon64(a, b, c, bias, (int)act_type, depth, col);
}
static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) {
size_t stride = col * 2;
asm volatile(
"mov x10, %[src_c]\n"
"mov x11, %[dst_c]\n"
"ld1 {v0.8h}, [x10], %[stride]\n"
"ld1 {v1.8h}, [x10], %[stride]\n"
"ld1 {v2.8h}, [x10], %[stride]\n"
"ld1 {v3.8h}, [x10], %[stride]\n"
"ld1 {v4.8h}, [x10], %[stride]\n"
"ld1 {v5.8h}, [x10], %[stride]\n"
"ld1 {v6.8h}, [x10], %[stride]\n"
"ld1 {v7.8h}, [x10], %[stride]\n"
"zip1 v16.8h, v0.8h, v1.8h\n"
"zip1 v17.8h, v2.8h, v3.8h\n"
"zip1 v18.8h, v4.8h, v5.8h\n"
"zip1 v19.8h, v6.8h, v7.8h\n"
"ld1 {v8.8h}, [x10], %[stride]\n"
"ld1 {v9.8h}, [x10], %[stride]\n"
"ld1 {v10.8h}, [x10], %[stride]\n"
"ld1 {v11.8h}, [x10], %[stride]\n"
"ld1 {v12.8h}, [x10], %[stride]\n"
"ld1 {v13.8h}, [x10], %[stride]\n"
"ld1 {v14.8h}, [x10], %[stride]\n"
"ld1 {v15.8h}, [x10], %[stride]\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v24.2d, v20.2d, v22.2d\n"
"trn2 v25.2d, v20.2d, v22.2d\n"
"trn1 v26.2d, v21.2d, v23.2d\n"
"trn2 v27.2d, v21.2d, v23.2d\n"
"zip1 v16.8h, v8.8h, v9.8h\n"
"zip1 v17.8h, v10.8h, v11.8h\n"
"zip1 v18.8h, v12.8h, v13.8h\n"
"zip1 v19.8h, v14.8h, v15.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v28.2d, v20.2d, v22.2d\n"
"trn2 v29.2d, v20.2d, v22.2d\n"
"trn1 v30.2d, v21.2d, v23.2d\n"
"trn2 v31.2d, v21.2d, v23.2d\n"
"st1 {v24.8h}, [x11], #16\n"
"st1 {v28.8h}, [x11], #16\n"
"st1 {v26.8h}, [x11], #16\n"
"st1 {v30.8h}, [x11], #16\n"
"st1 {v25.8h}, [x11], #16\n"
"st1 {v29.8h}, [x11], #16\n"
"st1 {v27.8h}, [x11], #16\n"
"st1 {v31.8h}, [x11], #16\n"
"zip2 v16.8h, v0.8h, v1.8h\n"
"zip2 v17.8h, v2.8h, v3.8h\n"
"zip2 v18.8h, v4.8h, v5.8h\n"
"zip2 v19.8h, v6.8h, v7.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v24.2d, v20.2d, v22.2d\n"
"trn2 v25.2d, v20.2d, v22.2d\n"
"trn1 v26.2d, v21.2d, v23.2d\n"
"trn2 v27.2d, v21.2d, v23.2d\n"
"zip2 v16.8h, v8.8h, v9.8h\n"
"zip2 v17.8h, v10.8h, v11.8h\n"
"zip2 v18.8h, v12.8h, v13.8h\n"
"zip2 v19.8h, v14.8h, v15.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v28.2d, v20.2d, v22.2d\n"
"trn2 v29.2d, v20.2d, v22.2d\n"
"trn1 v30.2d, v21.2d, v23.2d\n"
"trn2 v31.2d, v21.2d, v23.2d\n"
"st1 {v24.8h}, [x11], #16\n"
"st1 {v28.8h}, [x11], #16\n"
"st1 {v26.8h}, [x11], #16\n"
"st1 {v30.8h}, [x11], #16\n"
"st1 {v25.8h}, [x11], #16\n"
"st1 {v29.8h}, [x11], #16\n"
"st1 {v27.8h}, [x11], #16\n"
"st1 {v31.8h}, [x11], #16\n"
:
: [ dst_c ] "r"(dst_ptr), [ src_c ] "r"(src_ptr), [ stride ] "r"(stride)
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31");
}
void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
size_t row_up_16 = UP_ROUND(row, C16NUM);
size_t row16 = row / C16NUM * C16NUM;
size_t col8 = col / C8NUM * C8NUM;
const float16_t *src_r = src_ptr;
float16_t *dst_r = dst_ptr;
size_t ri = 0;
// find 16 block unit
for (; ri < row16; ri += C16NUM) {
size_t ci = 0;
for (; ci < col8; ci += C8NUM) {
const float16_t *src_c = src_r + ci;
float16_t *dst_c = dst_r + ci * C16NUM;
#ifdef ENABLE_ARM64
size_t stride = col * 2;
asm volatile(
"mov x10, %[src_c]\n"
"mov x11, %[dst_c]\n"
"ld1 {v0.8h}, [x10], %[stride]\n"
"ld1 {v1.8h}, [x10], %[stride]\n"
"ld1 {v2.8h}, [x10], %[stride]\n"
"ld1 {v3.8h}, [x10], %[stride]\n"
"ld1 {v4.8h}, [x10], %[stride]\n"
"ld1 {v5.8h}, [x10], %[stride]\n"
"ld1 {v6.8h}, [x10], %[stride]\n"
"ld1 {v7.8h}, [x10], %[stride]\n"
"zip1 v16.8h, v0.8h, v1.8h\n"
"zip1 v17.8h, v2.8h, v3.8h\n"
"zip1 v18.8h, v4.8h, v5.8h\n"
"zip1 v19.8h, v6.8h, v7.8h\n"
"ld1 {v8.8h}, [x10], %[stride]\n"
"ld1 {v9.8h}, [x10], %[stride]\n"
"ld1 {v10.8h}, [x10], %[stride]\n"
"ld1 {v11.8h}, [x10], %[stride]\n"
"ld1 {v12.8h}, [x10], %[stride]\n"
"ld1 {v13.8h}, [x10], %[stride]\n"
"ld1 {v14.8h}, [x10], %[stride]\n"
"ld1 {v15.8h}, [x10], %[stride]\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v24.2d, v20.2d, v22.2d\n"
"trn2 v25.2d, v20.2d, v22.2d\n"
"trn1 v26.2d, v21.2d, v23.2d\n"
"trn2 v27.2d, v21.2d, v23.2d\n"
"zip1 v16.8h, v8.8h, v9.8h\n"
"zip1 v17.8h, v10.8h, v11.8h\n"
"zip1 v18.8h, v12.8h, v13.8h\n"
"zip1 v19.8h, v14.8h, v15.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v28.2d, v20.2d, v22.2d\n"
"trn2 v29.2d, v20.2d, v22.2d\n"
"trn1 v30.2d, v21.2d, v23.2d\n"
"trn2 v31.2d, v21.2d, v23.2d\n"
"st1 {v24.8h}, [x11], #16\n"
"st1 {v28.8h}, [x11], #16\n"
"st1 {v26.8h}, [x11], #16\n"
"st1 {v30.8h}, [x11], #16\n"
"st1 {v25.8h}, [x11], #16\n"
"st1 {v29.8h}, [x11], #16\n"
"st1 {v27.8h}, [x11], #16\n"
"st1 {v31.8h}, [x11], #16\n"
"zip2 v16.8h, v0.8h, v1.8h\n"
"zip2 v17.8h, v2.8h, v3.8h\n"
"zip2 v18.8h, v4.8h, v5.8h\n"
"zip2 v19.8h, v6.8h, v7.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v24.2d, v20.2d, v22.2d\n"
"trn2 v25.2d, v20.2d, v22.2d\n"
"trn1 v26.2d, v21.2d, v23.2d\n"
"trn2 v27.2d, v21.2d, v23.2d\n"
"zip2 v16.8h, v8.8h, v9.8h\n"
"zip2 v17.8h, v10.8h, v11.8h\n"
"zip2 v18.8h, v12.8h, v13.8h\n"
"zip2 v19.8h, v14.8h, v15.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v28.2d, v20.2d, v22.2d\n"
"trn2 v29.2d, v20.2d, v22.2d\n"
"trn1 v30.2d, v21.2d, v23.2d\n"
"trn2 v31.2d, v21.2d, v23.2d\n"
"st1 {v24.8h}, [x11], #16\n"
"st1 {v28.8h}, [x11], #16\n"
"st1 {v26.8h}, [x11], #16\n"
"st1 {v30.8h}, [x11], #16\n"
"st1 {v25.8h}, [x11], #16\n"
"st1 {v29.8h}, [x11], #16\n"
"st1 {v27.8h}, [x11], #16\n"
"st1 {v31.8h}, [x11], #16\n"
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
"v30", "v31");
Row2Col16Block16(src_c, dst_c, col);
#else
for (int tr = 0; tr < C16NUM; tr++) {
for (int tc = 0; tc < C8NUM; tc++) {
@ -413,7 +427,7 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si
dst_r += C16NUM * col;
}
for (; ri < row; ri++) {
for (size_t i = 0; i < col; i++) {
for (size_t i = 0; i < col; ++i) {
dst_r[i * C16NUM] = src_r[i];
}
src_r += col;

View File

@ -40,6 +40,9 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
int start_index = thread_id * cal_num;
int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num;
if (real_cal_num <= 0) {
return;
}
float *gemm_input = packed_input + task_id * deep * cal_num;
float *col_major_gemm_input = col_major_input + task_id * deep * cal_num;
size_t packed_input_size = deep * cal_num * sizeof(float);

View File

@ -56,6 +56,9 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
int out_tile_index = thread_id * tile_num;
int cal_num = output_count - out_tile_index;
cal_num = cal_num > tile_num ? tile_num : cal_num;
if (cal_num <= 0) {
return;
}
WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
in_func);

View File

@ -36,7 +36,6 @@ ConvolutionBaseCPUKernel::~ConvolutionBaseCPUKernel() {
}
void ConvolutionBaseCPUKernel::FreeQuantParam() {
ConvQuantArg *conv_quant_arg_ = &conv_param_->conv_quant_arg_;
if (conv_quant_arg_ == nullptr) {
return;
}

View File

@ -44,7 +44,10 @@ class ConvolutionDelegateFP16CPUKernel : public LiteKernel {
void FreeCopiedData();
int Init() override;
int ReSize() override;
int Run() override { return fp16_conv_kernel_->Run(); }
int Run() override {
fp16_conv_kernel_->set_name(name_);
return fp16_conv_kernel_->Run();
}
private:
uint8_t need_free_ = 0b00;

View File

@ -102,6 +102,13 @@ int ConvolutionFP16CPUKernel::Init() {
return RET_OK;
}
void ConvolutionFP16CPUKernel::AdjustNumberOfThread() {
auto out_tensor = out_tensors_.front();
int out_plane = out_tensor->Height() * out_tensor->Width();
thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(out_plane, C16NUM));
conv_param_->thread_num_ = thread_count_;
}
int ConvolutionFP16CPUKernel::ReSize() {
auto ret = ConvolutionBaseCPUKernel::CheckResizeValid();
if (ret != RET_OK) {

View File

@ -44,6 +44,7 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
int RunImpl(int task_id);
int InitWeightBias();
int InitTmpBuffer();
void AdjustNumberOfThread();
private:
void FreeTmpBuffer() {

View File

@ -108,7 +108,6 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
const int cal_num = 16;
int channel_out = conv_param_->output_channel_;
int oc8 = UP_DIV(channel_out, C8NUM);
size_t tile_buffer_size =
thread_count_ * cal_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t);
@ -118,8 +117,8 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
return RET_ERROR;
}
gemm_out_ = reinterpret_cast<float16_t *>(
ctx_->allocator->Malloc(thread_count_ * cal_num * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float16_t)));
gemm_out_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(
thread_count_ * cal_num * input_unit_ * input_unit_ * UP_ROUND(channel_out, C8NUM) * sizeof(float16_t)));
if (gemm_out_ == nullptr) {
MS_LOG(ERROR) << "malloc gemm_out_ failed.";
return RET_ERROR;
@ -174,6 +173,13 @@ int ConvolutionWinogradFP16CPUKernel::Init() {
return RET_OK;
}
void ConvolutionWinogradFP16CPUKernel::AdjustNumberOfThread() {
auto out_tensor = out_tensors_.front();
int cal_plane = UP_DIV(out_tensor->Height(), output_unit_) * UP_DIV(out_tensor->Width(), output_unit_);
thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(cal_plane, C8NUM));
conv_param_->thread_num_ = thread_count_;
}
int ConvolutionWinogradFP16CPUKernel::ReSize() {
auto ret = ConvolutionBaseCPUKernel::CheckResizeValid();
if (ret != RET_OK) {
@ -190,6 +196,7 @@ int ConvolutionWinogradFP16CPUKernel::ReSize() {
MS_LOG(ERROR) << "ConfigInputOutput failed.";
return RET_ERROR;
}
AdjustNumberOfThread();
return RET_OK;
}

View File

@ -52,6 +52,7 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
int InitTmpBuffer();
int ConfigInputOutput();
int WinogradFilterTransformFp16(const float16_t *weight_data, float *matrix_g, float *matrix_gt, int oc_block);
void AdjustNumberOfThread();
private:
void FreeTmpBuffer() {

View File

@ -48,16 +48,9 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
conv_param_->input_channel_ = in_channel;
conv_param_->output_channel_ = out_channel;
int oc4 = UP_DIV(out_channel, C4NUM);
#ifdef ENABLE_AVX
const int oc_block = C16NUM;
#else
const int oc_block = C8NUM;
#endif
int oc_block_num = UP_DIV(out_channel, oc_block);
// set data
auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float);
auto trans_matrix_data_size =
input_unit_ * input_unit_ * in_channel * UP_ROUND(out_channel, oc_block_) * sizeof(float);
if (trans_weight_ == nullptr) {
trans_weight_ = reinterpret_cast<float *>(malloc(trans_matrix_data_size));
if (trans_weight_ == nullptr) {
@ -83,14 +76,15 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
MS_LOG(ERROR) << "get matrix g from CookToomFilter failed.";
return ret;
}
ret = WinogradFilterTransform(origin_weight_, matrix_g, matrix_gt, oc_block);
ret = WinogradFilterTransform(origin_weight_, matrix_g, matrix_gt, oc_block_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "winograd filter transform failed.";
return ret;
}
// init bias
size_t new_bias_size = oc4 * C4NUM * sizeof(float);
size_t new_bias_size = UP_ROUND(out_channel, C4NUM) * sizeof(float);
bias_data_ = malloc(new_bias_size);
if (bias_data_ == nullptr) {
bias_data_ = reinterpret_cast<float *>(malloc(new_bias_size));
if (bias_data_ == nullptr) {
@ -98,31 +92,30 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
return RET_MEMORY_FAILED;
}
}
memset(bias_data_, 0, new_bias_size);
if (in_tensors_.size() == kInputSize2) {
memcpy(bias_data_, origin_bias_, out_channel * sizeof(float));
size_t origin_size = out_channel * sizeof(float);
memcpy(bias_data_, origin_bias_, origin_size);
memset(reinterpret_cast<float *>(bias_data_) + out_channel, 0, new_bias_size - origin_size);
} else {
MS_ASSERT(in_tensors_.size() == kInputSize1);
memset(bias_data_, 0, new_bias_size);
}
return RET_OK;
}
int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
int channel_out = conv_param_->output_channel_;
int oc8 = UP_DIV(channel_out, C8NUM);
int tile_num = C12NUM;
MS_ASSERT(ctx_->allocator != nullptr);
size_t tile_buffer_size =
thread_count_ * tile_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float);
thread_count_ * tile_num_ * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float);
trans_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));
if (trans_input_ == nullptr) {
MS_LOG(ERROR) << "malloc trans_input_ failed.";
return RET_MEMORY_FAILED;
}
int oc8 = UP_ROUND(conv_param_->output_channel_, C8NUM);
gemm_out_ = reinterpret_cast<float *>(
ctx_->allocator->Malloc(thread_count_ * tile_num * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float)));
ctx_->allocator->Malloc(thread_count_ * tile_num_ * input_unit_ * input_unit_ * oc8 * sizeof(float)));
if (gemm_out_ == nullptr) {
MS_LOG(ERROR) << "malloc gemm_out_ failed.";
return RET_ERROR;
@ -136,7 +129,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
}
col_buffer_ = reinterpret_cast<float *>(
ctx_->allocator->Malloc(thread_count_ * tile_num * conv_param_->input_channel_ * sizeof(float)));
ctx_->allocator->Malloc(thread_count_ * tile_num_ * conv_param_->input_channel_ * sizeof(float)));
if (col_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc col_buffer_ failed.";
return RET_ERROR;
@ -164,10 +157,17 @@ int ConvolutionWinogradCPUKernel::ConfigInputOutput() {
}
int ConvolutionWinogradCPUKernel::Init() {
tile_num_ = C12NUM;
#ifdef ENABLE_AVX
oc_block_ = C16NUM;
#else
oc_block_ = C8NUM;
#endif
kernel_unit_ = conv_param_->kernel_h_;
input_unit_ = output_unit_ + kernel_unit_ - 1;
conv_param_->input_unit_ = input_unit_;
conv_param_->output_unit_ = output_unit_;
auto ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init weight bias failed.";
@ -197,8 +197,8 @@ int ConvolutionWinogradCPUKernel::ReSize() {
int ConvolutionWinogradCPUKernel::RunImpl(int task_id) {
auto input_tensor = in_tensors_.at(kInputIndex);
auto ori_input_data = reinterpret_cast<float *>(input_tensor->MutableData());
auto output_data = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
auto ori_input_data = reinterpret_cast<float *>(input_tensor->data_c());
auto output_data = reinterpret_cast<float *>(out_tensors_.front()->data_c());
ConvWinogardFp32(ori_input_data, trans_weight_, reinterpret_cast<const float *>(bias_data_), output_data,
tmp_buffer_address_list_, task_id, conv_param_, in_func_, out_func_);
return RET_OK;

View File

@ -70,9 +70,11 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
col_buffer_ = nullptr;
}
}
int kernel_unit_;
int input_unit_;
int kernel_unit_{0};
int input_unit_{0};
int output_unit_;
int oc_block_{0};
int tile_num_{0};
float *origin_weight_; // do not free
float *origin_bias_; // do not free
float *tmp_data_ = nullptr;