forked from mindspore-Ecosystem/mindspore
!11977 [MS][LITE][CPU]change the parallel strategy for fp16 winograd
From: @fuzhiye Reviewed-by: @zhang_xue_tong,@hangangqiang Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
ca3f916c1e
|
@ -160,7 +160,9 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
|
|||
int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_);
|
||||
int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_);
|
||||
int output_count = out_w_block * out_h_block;
|
||||
int output_tile_count = UP_DIV(output_count, tile_num);
|
||||
int per_thread_num = UP_DIV(output_count, conv_param->thread_num_);
|
||||
int real_tile = per_thread_num < tile_num ? per_thread_num : tile_num;
|
||||
int output_tile_count = UP_DIV(output_count, real_tile);
|
||||
int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
|
||||
int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_;
|
||||
|
||||
|
@ -178,9 +180,12 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
|
|||
int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_;
|
||||
int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_;
|
||||
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
|
||||
int out_tile_index = thread_id * tile_num;
|
||||
int cal_num = output_count - thread_id * tile_num;
|
||||
cal_num = cal_num > tile_num ? tile_num : cal_num;
|
||||
int out_tile_index = thread_id * real_tile;
|
||||
int cal_num = output_count - thread_id * real_tile;
|
||||
cal_num = cal_num > real_tile ? real_tile : cal_num;
|
||||
if (cal_num <= 0) {
|
||||
return;
|
||||
}
|
||||
WinogradInputTransformFp16(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
|
||||
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
in_func);
|
||||
|
@ -189,7 +194,7 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
|
|||
float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset;
|
||||
float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
|
||||
for (int i = 0; i < input_unit_square; ++i) {
|
||||
RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, tile_num, in_channel);
|
||||
RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
|
||||
MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
|
||||
cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8);
|
||||
}
|
||||
|
|
|
@ -16,201 +16,212 @@
|
|||
|
||||
#include "nnacl/fp16/matmul_fp16.h"
|
||||
|
||||
void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) {
|
||||
static void Col2Row8SrcFromFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
|
||||
int row_c8 = row / C8NUM * C8NUM;
|
||||
int col_c8 = col / C8NUM * C8NUM;
|
||||
const float16_t *src = (const float16_t *)src_ptr;
|
||||
int ci = 0;
|
||||
for (; ci < col_c8; ci += C8NUM) {
|
||||
int ri = 0;
|
||||
for (; ri < row_c8; ri += C8NUM) {
|
||||
const float16_t *src_ptr1 = src + ci * row + ri;
|
||||
float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t strid_row = row * 2;
|
||||
asm volatile(
|
||||
"mov x10, %[src_ptr1]\n"
|
||||
"mov x11, %[dst_ptr1]\n"
|
||||
"mov x12, %[strid_row]\n"
|
||||
"ld1 {v0.8h}, [x10], x12\n"
|
||||
"ld1 {v1.8h}, [x10], x12\n"
|
||||
"ld1 {v2.8h}, [x10], x12\n"
|
||||
"ld1 {v3.8h}, [x10], x12\n"
|
||||
"ld1 {v4.8h}, [x10], x12\n"
|
||||
"ld1 {v5.8h}, [x10], x12\n"
|
||||
"ld1 {v6.8h}, [x10], x12\n"
|
||||
"ld1 {v7.8h}, [x10], x12\n"
|
||||
|
||||
"zip1 v8.8h, v0.8h, v1.8h\n"
|
||||
"zip1 v9.8h, v2.8h, v3.8h\n"
|
||||
"zip1 v10.8h, v4.8h, v5.8h\n"
|
||||
"zip1 v11.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"trn1 v12.4s, v8.4s, v9.4s\n"
|
||||
"trn1 v14.4s, v10.4s, v11.4s\n"
|
||||
"trn2 v13.4s, v8.4s, v9.4s\n"
|
||||
"trn2 v15.4s, v10.4s, v11.4s\n"
|
||||
|
||||
"trn1 v16.2d, v12.2d, v14.2d\n"
|
||||
"trn2 v18.2d, v12.2d, v14.2d\n"
|
||||
"trn1 v17.2d, v13.2d, v15.2d\n"
|
||||
"trn2 v19.2d, v13.2d, v15.2d\n"
|
||||
|
||||
"zip2 v8.8h, v0.8h, v1.8h\n"
|
||||
"zip2 v9.8h, v2.8h, v3.8h\n"
|
||||
"zip2 v10.8h, v4.8h, v5.8h\n"
|
||||
"zip2 v11.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"trn1 v12.4s, v8.4s, v9.4s\n"
|
||||
"trn1 v14.4s, v10.4s, v11.4s\n"
|
||||
"trn2 v13.4s, v8.4s, v9.4s\n"
|
||||
"trn2 v15.4s, v10.4s, v11.4s\n"
|
||||
|
||||
"trn1 v20.2d, v12.2d, v14.2d\n"
|
||||
"trn2 v22.2d, v12.2d, v14.2d\n"
|
||||
"trn1 v21.2d, v13.2d, v15.2d\n"
|
||||
"trn2 v23.2d, v13.2d, v15.2d\n"
|
||||
|
||||
"st1 {v16.8h}, [x11], #16\n"
|
||||
"st1 {v17.8h}, [x11], #16\n"
|
||||
"st1 {v18.8h}, [x11], #16\n"
|
||||
"st1 {v19.8h}, [x11], #16\n"
|
||||
"st1 {v20.8h}, [x11], #16\n"
|
||||
"st1 {v21.8h}, [x11], #16\n"
|
||||
"st1 {v22.8h}, [x11], #16\n"
|
||||
"st1 {v23.8h}, [x11], #16\n"
|
||||
:
|
||||
: [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row)
|
||||
: "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
|
||||
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
|
||||
#else
|
||||
for (int tr = 0; tr < C8NUM; ++tr) {
|
||||
for (int tc = 0; tc < C8NUM; ++tc) {
|
||||
dst_ptr1[tr * C8NUM + tc] = src_ptr1[tc * row + tr];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (; ri < row; ++ri) {
|
||||
const float16_t *src_ptr1 = src + ci * row;
|
||||
float16_t *dst_ptr1 = dst_ptr + ci * row;
|
||||
for (int tc = 0; tc < C8NUM; ++tc) {
|
||||
dst_ptr1[ri * C8NUM + tc] = src_ptr1[tc * row + ri];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int r = 0; r < row; r++) {
|
||||
for (int tc = ci; tc < col; tc++) {
|
||||
int cd8 = tc / C8NUM;
|
||||
int cm8 = tc % C8NUM;
|
||||
dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[tc * row + r];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void Col2Row8SrcFromFp32(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
|
||||
int row_c8 = row / C8NUM * C8NUM;
|
||||
int col_c8 = col / C8NUM * C8NUM;
|
||||
int ci = 0;
|
||||
const float *src = (const float *)src_ptr;
|
||||
for (; ci < col_c8; ci += C8NUM) {
|
||||
int ri = 0;
|
||||
for (; ri < row_c8; ri += C8NUM) {
|
||||
const float *src_ptr1 = src + ci * row + ri;
|
||||
float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t strid_row = row * 4;
|
||||
asm volatile(
|
||||
"mov x10, %[src_ptr1]\n"
|
||||
"mov x11, %[dst_ptr1]\n"
|
||||
"mov x12, %[strid_row]\n"
|
||||
"ld1 {v8.4s, v9.4s}, [x10], x12\n"
|
||||
"ld1 {v10.4s, v11.4s}, [x10], x12\n"
|
||||
"ld1 {v12.4s, v13.4s}, [x10], x12\n"
|
||||
"ld1 {v14.4s, v15.4s}, [x10], x12\n"
|
||||
"ld1 {v16.4s, v17.4s}, [x10], x12\n"
|
||||
"ld1 {v18.4s, v19.4s}, [x10], x12\n"
|
||||
"ld1 {v20.4s, v21.4s}, [x10], x12\n"
|
||||
"ld1 {v22.4s, v23.4s}, [x10], x12\n"
|
||||
|
||||
"fcvtn v0.4h, v8.4s\n"
|
||||
"fcvtn2 v0.8h, v9.4s\n"
|
||||
"fcvtn v1.4h, v10.4s\n"
|
||||
"fcvtn2 v1.8h, v11.4s\n"
|
||||
"fcvtn v2.4h, v12.4s\n"
|
||||
"fcvtn2 v2.8h, v13.4s\n"
|
||||
"fcvtn v3.4h, v14.4s\n"
|
||||
"fcvtn2 v3.8h, v15.4s\n"
|
||||
"fcvtn v4.4h, v16.4s\n"
|
||||
"fcvtn2 v4.8h, v17.4s\n"
|
||||
"fcvtn v5.4h, v18.4s\n"
|
||||
"fcvtn2 v5.8h, v19.4s\n"
|
||||
"fcvtn v6.4h, v20.4s\n"
|
||||
"fcvtn2 v6.8h, v21.4s\n"
|
||||
"fcvtn v7.4h, v22.4s\n"
|
||||
"fcvtn2 v7.8h, v23.4s\n"
|
||||
|
||||
"zip1 v8.8h, v0.8h, v1.8h\n"
|
||||
"zip1 v9.8h, v2.8h, v3.8h\n"
|
||||
"zip1 v10.8h, v4.8h, v5.8h\n"
|
||||
"zip1 v11.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"trn1 v12.4s, v8.4s, v9.4s\n"
|
||||
"trn1 v14.4s, v10.4s, v11.4s\n"
|
||||
"trn2 v13.4s, v8.4s, v9.4s\n"
|
||||
"trn2 v15.4s, v10.4s, v11.4s\n"
|
||||
|
||||
"trn1 v16.2d, v12.2d, v14.2d\n"
|
||||
"trn2 v18.2d, v12.2d, v14.2d\n"
|
||||
"trn1 v17.2d, v13.2d, v15.2d\n"
|
||||
"trn2 v19.2d, v13.2d, v15.2d\n"
|
||||
|
||||
"zip2 v8.8h, v0.8h, v1.8h\n"
|
||||
"zip2 v9.8h, v2.8h, v3.8h\n"
|
||||
"zip2 v10.8h, v4.8h, v5.8h\n"
|
||||
"zip2 v11.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"trn1 v12.4s, v8.4s, v9.4s\n"
|
||||
"trn1 v14.4s, v10.4s, v11.4s\n"
|
||||
"trn2 v13.4s, v8.4s, v9.4s\n"
|
||||
"trn2 v15.4s, v10.4s, v11.4s\n"
|
||||
|
||||
"trn1 v20.2d, v12.2d, v14.2d\n"
|
||||
"trn2 v22.2d, v12.2d, v14.2d\n"
|
||||
"trn1 v21.2d, v13.2d, v15.2d\n"
|
||||
"trn2 v23.2d, v13.2d, v15.2d\n"
|
||||
|
||||
"st1 {v16.8h}, [x11], #16\n"
|
||||
"st1 {v17.8h}, [x11], #16\n"
|
||||
"st1 {v18.8h}, [x11], #16\n"
|
||||
"st1 {v19.8h}, [x11], #16\n"
|
||||
"st1 {v20.8h}, [x11], #16\n"
|
||||
"st1 {v21.8h}, [x11], #16\n"
|
||||
"st1 {v22.8h}, [x11], #16\n"
|
||||
"st1 {v23.8h}, [x11], #16\n"
|
||||
:
|
||||
: [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row)
|
||||
: "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
|
||||
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
|
||||
#else
|
||||
for (int tr = 0; tr < C8NUM; ++tr) {
|
||||
for (int tc = 0; tc < C8NUM; ++tc) {
|
||||
dst_ptr1[tr * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + tr]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (; ri < row; ++ri) {
|
||||
const float *src_ptr1 = src + ci * row;
|
||||
float16_t *dst_ptr1 = dst_ptr + ci * row;
|
||||
for (int tc = 0; tc < C8NUM; ++tc) {
|
||||
dst_ptr1[ri * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + ri]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int r = 0; r < row; r++) {
|
||||
for (int tc = ci; tc < col; tc++) {
|
||||
int cd8 = tc / C8NUM;
|
||||
int cm8 = tc % C8NUM;
|
||||
dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = (float16_t)(src[tc * row + r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) {
|
||||
if (src_float16) {
|
||||
const float16_t *src = (const float16_t *)src_ptr;
|
||||
for (; ci < col_c8; ci += C8NUM) {
|
||||
int ri = 0;
|
||||
for (; ri < row_c8; ri += C8NUM) {
|
||||
const float16_t *src_ptr1 = src + ci * row + ri;
|
||||
float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t strid_row = row * 2;
|
||||
asm volatile(
|
||||
"mov x10, %[src_ptr1]\n"
|
||||
"mov x11, %[dst_ptr1]\n"
|
||||
"mov x12, %[strid_row]\n"
|
||||
"ld1 {v0.8h}, [x10], x12\n"
|
||||
"ld1 {v1.8h}, [x10], x12\n"
|
||||
"ld1 {v2.8h}, [x10], x12\n"
|
||||
"ld1 {v3.8h}, [x10], x12\n"
|
||||
"ld1 {v4.8h}, [x10], x12\n"
|
||||
"ld1 {v5.8h}, [x10], x12\n"
|
||||
"ld1 {v6.8h}, [x10], x12\n"
|
||||
"ld1 {v7.8h}, [x10], x12\n"
|
||||
|
||||
"zip1 v8.8h, v0.8h, v1.8h\n"
|
||||
"zip1 v9.8h, v2.8h, v3.8h\n"
|
||||
"zip1 v10.8h, v4.8h, v5.8h\n"
|
||||
"zip1 v11.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"trn1 v12.4s, v8.4s, v9.4s\n"
|
||||
"trn1 v14.4s, v10.4s, v11.4s\n"
|
||||
"trn2 v13.4s, v8.4s, v9.4s\n"
|
||||
"trn2 v15.4s, v10.4s, v11.4s\n"
|
||||
|
||||
"trn1 v16.2d, v12.2d, v14.2d\n"
|
||||
"trn2 v18.2d, v12.2d, v14.2d\n"
|
||||
"trn1 v17.2d, v13.2d, v15.2d\n"
|
||||
"trn2 v19.2d, v13.2d, v15.2d\n"
|
||||
|
||||
"zip2 v8.8h, v0.8h, v1.8h\n"
|
||||
"zip2 v9.8h, v2.8h, v3.8h\n"
|
||||
"zip2 v10.8h, v4.8h, v5.8h\n"
|
||||
"zip2 v11.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"trn1 v12.4s, v8.4s, v9.4s\n"
|
||||
"trn1 v14.4s, v10.4s, v11.4s\n"
|
||||
"trn2 v13.4s, v8.4s, v9.4s\n"
|
||||
"trn2 v15.4s, v10.4s, v11.4s\n"
|
||||
|
||||
"trn1 v20.2d, v12.2d, v14.2d\n"
|
||||
"trn2 v22.2d, v12.2d, v14.2d\n"
|
||||
"trn1 v21.2d, v13.2d, v15.2d\n"
|
||||
"trn2 v23.2d, v13.2d, v15.2d\n"
|
||||
|
||||
"st1 {v16.8h}, [x11], #16\n"
|
||||
"st1 {v17.8h}, [x11], #16\n"
|
||||
"st1 {v18.8h}, [x11], #16\n"
|
||||
"st1 {v19.8h}, [x11], #16\n"
|
||||
"st1 {v20.8h}, [x11], #16\n"
|
||||
"st1 {v21.8h}, [x11], #16\n"
|
||||
"st1 {v22.8h}, [x11], #16\n"
|
||||
"st1 {v23.8h}, [x11], #16\n"
|
||||
:
|
||||
: [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row)
|
||||
: "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
|
||||
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
|
||||
#else
|
||||
for (int tr = 0; tr < C8NUM; ++tr) {
|
||||
for (int tc = 0; tc < C8NUM; ++tc) {
|
||||
dst_ptr1[tr * C8NUM + tc] = src_ptr1[tc * row + tr];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (; ri < row; ++ri) {
|
||||
const float16_t *src_ptr1 = src + ci * row;
|
||||
float16_t *dst_ptr1 = dst_ptr + ci * row;
|
||||
for (int tc = 0; tc < C8NUM; ++tc) {
|
||||
dst_ptr1[ri * C8NUM + tc] = src_ptr1[tc * row + ri];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int r = 0; r < row; r++) {
|
||||
for (int tc = ci; tc < col; tc++) {
|
||||
int cd8 = tc / C8NUM;
|
||||
int cm8 = tc % C8NUM;
|
||||
dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[tc * row + r];
|
||||
}
|
||||
}
|
||||
Col2Row8SrcFromFp16(src_ptr, dst_ptr, row, col);
|
||||
} else {
|
||||
const float *src = (const float *)src_ptr;
|
||||
for (; ci < col_c8; ci += C8NUM) {
|
||||
int ri = 0;
|
||||
for (; ri < row_c8; ri += C8NUM) {
|
||||
const float *src_ptr1 = src + ci * row + ri;
|
||||
float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t strid_row = row * 4;
|
||||
asm volatile(
|
||||
"mov x10, %[src_ptr1]\n"
|
||||
"mov x11, %[dst_ptr1]\n"
|
||||
"mov x12, %[strid_row]\n"
|
||||
"ld1 {v8.4s, v9.4s}, [x10], x12\n"
|
||||
"ld1 {v10.4s, v11.4s}, [x10], x12\n"
|
||||
"ld1 {v12.4s, v13.4s}, [x10], x12\n"
|
||||
"ld1 {v14.4s, v15.4s}, [x10], x12\n"
|
||||
"ld1 {v16.4s, v17.4s}, [x10], x12\n"
|
||||
"ld1 {v18.4s, v19.4s}, [x10], x12\n"
|
||||
"ld1 {v20.4s, v21.4s}, [x10], x12\n"
|
||||
"ld1 {v22.4s, v23.4s}, [x10], x12\n"
|
||||
|
||||
"fcvtn v0.4h, v8.4s\n"
|
||||
"fcvtn2 v0.8h, v9.4s\n"
|
||||
"fcvtn v1.4h, v10.4s\n"
|
||||
"fcvtn2 v1.8h, v11.4s\n"
|
||||
"fcvtn v2.4h, v12.4s\n"
|
||||
"fcvtn2 v2.8h, v13.4s\n"
|
||||
"fcvtn v3.4h, v14.4s\n"
|
||||
"fcvtn2 v3.8h, v15.4s\n"
|
||||
"fcvtn v4.4h, v16.4s\n"
|
||||
"fcvtn2 v4.8h, v17.4s\n"
|
||||
"fcvtn v5.4h, v18.4s\n"
|
||||
"fcvtn2 v5.8h, v19.4s\n"
|
||||
"fcvtn v6.4h, v20.4s\n"
|
||||
"fcvtn2 v6.8h, v21.4s\n"
|
||||
"fcvtn v7.4h, v22.4s\n"
|
||||
"fcvtn2 v7.8h, v23.4s\n"
|
||||
|
||||
"zip1 v8.8h, v0.8h, v1.8h\n"
|
||||
"zip1 v9.8h, v2.8h, v3.8h\n"
|
||||
"zip1 v10.8h, v4.8h, v5.8h\n"
|
||||
"zip1 v11.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"trn1 v12.4s, v8.4s, v9.4s\n"
|
||||
"trn1 v14.4s, v10.4s, v11.4s\n"
|
||||
"trn2 v13.4s, v8.4s, v9.4s\n"
|
||||
"trn2 v15.4s, v10.4s, v11.4s\n"
|
||||
|
||||
"trn1 v16.2d, v12.2d, v14.2d\n"
|
||||
"trn2 v18.2d, v12.2d, v14.2d\n"
|
||||
"trn1 v17.2d, v13.2d, v15.2d\n"
|
||||
"trn2 v19.2d, v13.2d, v15.2d\n"
|
||||
|
||||
"zip2 v8.8h, v0.8h, v1.8h\n"
|
||||
"zip2 v9.8h, v2.8h, v3.8h\n"
|
||||
"zip2 v10.8h, v4.8h, v5.8h\n"
|
||||
"zip2 v11.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"trn1 v12.4s, v8.4s, v9.4s\n"
|
||||
"trn1 v14.4s, v10.4s, v11.4s\n"
|
||||
"trn2 v13.4s, v8.4s, v9.4s\n"
|
||||
"trn2 v15.4s, v10.4s, v11.4s\n"
|
||||
|
||||
"trn1 v20.2d, v12.2d, v14.2d\n"
|
||||
"trn2 v22.2d, v12.2d, v14.2d\n"
|
||||
"trn1 v21.2d, v13.2d, v15.2d\n"
|
||||
"trn2 v23.2d, v13.2d, v15.2d\n"
|
||||
|
||||
"st1 {v16.8h}, [x11], #16\n"
|
||||
"st1 {v17.8h}, [x11], #16\n"
|
||||
"st1 {v18.8h}, [x11], #16\n"
|
||||
"st1 {v19.8h}, [x11], #16\n"
|
||||
"st1 {v20.8h}, [x11], #16\n"
|
||||
"st1 {v21.8h}, [x11], #16\n"
|
||||
"st1 {v22.8h}, [x11], #16\n"
|
||||
"st1 {v23.8h}, [x11], #16\n"
|
||||
:
|
||||
: [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row)
|
||||
: "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
|
||||
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
|
||||
#else
|
||||
for (int tr = 0; tr < C8NUM; ++tr) {
|
||||
for (int tc = 0; tc < C8NUM; ++tc) {
|
||||
dst_ptr1[tr * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + tr]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (; ri < row; ++ri) {
|
||||
const float *src_ptr1 = src + ci * row;
|
||||
float16_t *dst_ptr1 = dst_ptr + ci * row;
|
||||
for (int tc = 0; tc < C8NUM; ++tc) {
|
||||
dst_ptr1[ri * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + ri]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int r = 0; r < row; r++) {
|
||||
for (int tc = ci; tc < col; tc++) {
|
||||
int cd8 = tc / C8NUM;
|
||||
int cm8 = tc % C8NUM;
|
||||
dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = (float16_t)(src[tc * row + r]);
|
||||
}
|
||||
}
|
||||
Col2Row8SrcFromFp32(src_ptr, dst_ptr, row, col);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@ -274,126 +285,129 @@ void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const f
|
|||
MatVecMulFp16Neon64(a, b, c, bias, (int)act_type, depth, col);
|
||||
}
|
||||
|
||||
static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) {
|
||||
size_t stride = col * 2;
|
||||
asm volatile(
|
||||
"mov x10, %[src_c]\n"
|
||||
"mov x11, %[dst_c]\n"
|
||||
|
||||
"ld1 {v0.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v1.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v2.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v3.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v4.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v5.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v6.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v7.8h}, [x10], %[stride]\n"
|
||||
|
||||
"zip1 v16.8h, v0.8h, v1.8h\n"
|
||||
"zip1 v17.8h, v2.8h, v3.8h\n"
|
||||
"zip1 v18.8h, v4.8h, v5.8h\n"
|
||||
"zip1 v19.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"ld1 {v8.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v9.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v10.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v11.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v12.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v13.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v14.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v15.8h}, [x10], %[stride]\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v24.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v25.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v26.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v27.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"zip1 v16.8h, v8.8h, v9.8h\n"
|
||||
"zip1 v17.8h, v10.8h, v11.8h\n"
|
||||
"zip1 v18.8h, v12.8h, v13.8h\n"
|
||||
"zip1 v19.8h, v14.8h, v15.8h\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v28.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v29.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v30.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v31.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"st1 {v24.8h}, [x11], #16\n"
|
||||
"st1 {v28.8h}, [x11], #16\n"
|
||||
"st1 {v26.8h}, [x11], #16\n"
|
||||
"st1 {v30.8h}, [x11], #16\n"
|
||||
"st1 {v25.8h}, [x11], #16\n"
|
||||
"st1 {v29.8h}, [x11], #16\n"
|
||||
"st1 {v27.8h}, [x11], #16\n"
|
||||
"st1 {v31.8h}, [x11], #16\n"
|
||||
|
||||
"zip2 v16.8h, v0.8h, v1.8h\n"
|
||||
"zip2 v17.8h, v2.8h, v3.8h\n"
|
||||
"zip2 v18.8h, v4.8h, v5.8h\n"
|
||||
"zip2 v19.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v24.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v25.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v26.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v27.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"zip2 v16.8h, v8.8h, v9.8h\n"
|
||||
"zip2 v17.8h, v10.8h, v11.8h\n"
|
||||
"zip2 v18.8h, v12.8h, v13.8h\n"
|
||||
"zip2 v19.8h, v14.8h, v15.8h\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v28.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v29.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v30.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v31.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"st1 {v24.8h}, [x11], #16\n"
|
||||
"st1 {v28.8h}, [x11], #16\n"
|
||||
"st1 {v26.8h}, [x11], #16\n"
|
||||
"st1 {v30.8h}, [x11], #16\n"
|
||||
"st1 {v25.8h}, [x11], #16\n"
|
||||
"st1 {v29.8h}, [x11], #16\n"
|
||||
"st1 {v27.8h}, [x11], #16\n"
|
||||
"st1 {v31.8h}, [x11], #16\n"
|
||||
:
|
||||
: [ dst_c ] "r"(dst_ptr), [ src_c ] "r"(src_ptr), [ stride ] "r"(stride)
|
||||
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
|
||||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
|
||||
"v31");
|
||||
}
|
||||
|
||||
void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
|
||||
size_t row_up_16 = UP_ROUND(row, C16NUM);
|
||||
size_t row16 = row / C16NUM * C16NUM;
|
||||
size_t col8 = col / C8NUM * C8NUM;
|
||||
const float16_t *src_r = src_ptr;
|
||||
float16_t *dst_r = dst_ptr;
|
||||
|
||||
size_t ri = 0;
|
||||
// find 16 block unit
|
||||
for (; ri < row16; ri += C16NUM) {
|
||||
size_t ci = 0;
|
||||
for (; ci < col8; ci += C8NUM) {
|
||||
const float16_t *src_c = src_r + ci;
|
||||
float16_t *dst_c = dst_r + ci * C16NUM;
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t stride = col * 2;
|
||||
asm volatile(
|
||||
"mov x10, %[src_c]\n"
|
||||
"mov x11, %[dst_c]\n"
|
||||
|
||||
"ld1 {v0.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v1.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v2.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v3.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v4.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v5.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v6.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v7.8h}, [x10], %[stride]\n"
|
||||
|
||||
"zip1 v16.8h, v0.8h, v1.8h\n"
|
||||
"zip1 v17.8h, v2.8h, v3.8h\n"
|
||||
"zip1 v18.8h, v4.8h, v5.8h\n"
|
||||
"zip1 v19.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"ld1 {v8.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v9.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v10.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v11.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v12.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v13.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v14.8h}, [x10], %[stride]\n"
|
||||
"ld1 {v15.8h}, [x10], %[stride]\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v24.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v25.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v26.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v27.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"zip1 v16.8h, v8.8h, v9.8h\n"
|
||||
"zip1 v17.8h, v10.8h, v11.8h\n"
|
||||
"zip1 v18.8h, v12.8h, v13.8h\n"
|
||||
"zip1 v19.8h, v14.8h, v15.8h\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v28.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v29.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v30.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v31.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"st1 {v24.8h}, [x11], #16\n"
|
||||
"st1 {v28.8h}, [x11], #16\n"
|
||||
"st1 {v26.8h}, [x11], #16\n"
|
||||
"st1 {v30.8h}, [x11], #16\n"
|
||||
"st1 {v25.8h}, [x11], #16\n"
|
||||
"st1 {v29.8h}, [x11], #16\n"
|
||||
"st1 {v27.8h}, [x11], #16\n"
|
||||
"st1 {v31.8h}, [x11], #16\n"
|
||||
|
||||
"zip2 v16.8h, v0.8h, v1.8h\n"
|
||||
"zip2 v17.8h, v2.8h, v3.8h\n"
|
||||
"zip2 v18.8h, v4.8h, v5.8h\n"
|
||||
"zip2 v19.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v24.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v25.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v26.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v27.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"zip2 v16.8h, v8.8h, v9.8h\n"
|
||||
"zip2 v17.8h, v10.8h, v11.8h\n"
|
||||
"zip2 v18.8h, v12.8h, v13.8h\n"
|
||||
"zip2 v19.8h, v14.8h, v15.8h\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v28.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v29.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v30.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v31.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"st1 {v24.8h}, [x11], #16\n"
|
||||
"st1 {v28.8h}, [x11], #16\n"
|
||||
"st1 {v26.8h}, [x11], #16\n"
|
||||
"st1 {v30.8h}, [x11], #16\n"
|
||||
"st1 {v25.8h}, [x11], #16\n"
|
||||
"st1 {v29.8h}, [x11], #16\n"
|
||||
"st1 {v27.8h}, [x11], #16\n"
|
||||
"st1 {v31.8h}, [x11], #16\n"
|
||||
:
|
||||
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
|
||||
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
|
||||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
|
||||
"v30", "v31");
|
||||
Row2Col16Block16(src_c, dst_c, col);
|
||||
#else
|
||||
for (int tr = 0; tr < C16NUM; tr++) {
|
||||
for (int tc = 0; tc < C8NUM; tc++) {
|
||||
|
@ -413,7 +427,7 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si
|
|||
dst_r += C16NUM * col;
|
||||
}
|
||||
for (; ri < row; ri++) {
|
||||
for (size_t i = 0; i < col; i++) {
|
||||
for (size_t i = 0; i < col; ++i) {
|
||||
dst_r[i * C16NUM] = src_r[i];
|
||||
}
|
||||
src_r += col;
|
||||
|
|
|
@ -40,6 +40,9 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
|
|||
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
|
||||
int start_index = thread_id * cal_num;
|
||||
int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num;
|
||||
if (real_cal_num <= 0) {
|
||||
return;
|
||||
}
|
||||
float *gemm_input = packed_input + task_id * deep * cal_num;
|
||||
float *col_major_gemm_input = col_major_input + task_id * deep * cal_num;
|
||||
size_t packed_input_size = deep * cal_num * sizeof(float);
|
||||
|
|
|
@ -56,6 +56,9 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
|
|||
int out_tile_index = thread_id * tile_num;
|
||||
int cal_num = output_count - out_tile_index;
|
||||
cal_num = cal_num > tile_num ? tile_num : cal_num;
|
||||
if (cal_num <= 0) {
|
||||
return;
|
||||
}
|
||||
WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
|
||||
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
in_func);
|
||||
|
|
|
@ -36,7 +36,6 @@ ConvolutionBaseCPUKernel::~ConvolutionBaseCPUKernel() {
|
|||
}
|
||||
|
||||
void ConvolutionBaseCPUKernel::FreeQuantParam() {
|
||||
ConvQuantArg *conv_quant_arg_ = &conv_param_->conv_quant_arg_;
|
||||
if (conv_quant_arg_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -44,7 +44,10 @@ class ConvolutionDelegateFP16CPUKernel : public LiteKernel {
|
|||
void FreeCopiedData();
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override { return fp16_conv_kernel_->Run(); }
|
||||
int Run() override {
|
||||
fp16_conv_kernel_->set_name(name_);
|
||||
return fp16_conv_kernel_->Run();
|
||||
}
|
||||
|
||||
private:
|
||||
uint8_t need_free_ = 0b00;
|
||||
|
|
|
@ -102,6 +102,13 @@ int ConvolutionFP16CPUKernel::Init() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void ConvolutionFP16CPUKernel::AdjustNumberOfThread() {
|
||||
auto out_tensor = out_tensors_.front();
|
||||
int out_plane = out_tensor->Height() * out_tensor->Width();
|
||||
thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(out_plane, C16NUM));
|
||||
conv_param_->thread_num_ = thread_count_;
|
||||
}
|
||||
|
||||
int ConvolutionFP16CPUKernel::ReSize() {
|
||||
auto ret = ConvolutionBaseCPUKernel::CheckResizeValid();
|
||||
if (ret != RET_OK) {
|
||||
|
|
|
@ -44,6 +44,7 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
|
|||
int RunImpl(int task_id);
|
||||
int InitWeightBias();
|
||||
int InitTmpBuffer();
|
||||
void AdjustNumberOfThread();
|
||||
|
||||
private:
|
||||
void FreeTmpBuffer() {
|
||||
|
|
|
@ -108,7 +108,6 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
|
|||
int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
|
||||
const int cal_num = 16;
|
||||
int channel_out = conv_param_->output_channel_;
|
||||
int oc8 = UP_DIV(channel_out, C8NUM);
|
||||
|
||||
size_t tile_buffer_size =
|
||||
thread_count_ * cal_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t);
|
||||
|
@ -118,8 +117,8 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
gemm_out_ = reinterpret_cast<float16_t *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * cal_num * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float16_t)));
|
||||
gemm_out_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(
|
||||
thread_count_ * cal_num * input_unit_ * input_unit_ * UP_ROUND(channel_out, C8NUM) * sizeof(float16_t)));
|
||||
if (gemm_out_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc gemm_out_ failed.";
|
||||
return RET_ERROR;
|
||||
|
@ -174,6 +173,13 @@ int ConvolutionWinogradFP16CPUKernel::Init() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void ConvolutionWinogradFP16CPUKernel::AdjustNumberOfThread() {
|
||||
auto out_tensor = out_tensors_.front();
|
||||
int cal_plane = UP_DIV(out_tensor->Height(), output_unit_) * UP_DIV(out_tensor->Width(), output_unit_);
|
||||
thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(cal_plane, C8NUM));
|
||||
conv_param_->thread_num_ = thread_count_;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradFP16CPUKernel::ReSize() {
|
||||
auto ret = ConvolutionBaseCPUKernel::CheckResizeValid();
|
||||
if (ret != RET_OK) {
|
||||
|
@ -190,6 +196,7 @@ int ConvolutionWinogradFP16CPUKernel::ReSize() {
|
|||
MS_LOG(ERROR) << "ConfigInputOutput failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
AdjustNumberOfThread();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -52,6 +52,7 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
|
|||
int InitTmpBuffer();
|
||||
int ConfigInputOutput();
|
||||
int WinogradFilterTransformFp16(const float16_t *weight_data, float *matrix_g, float *matrix_gt, int oc_block);
|
||||
void AdjustNumberOfThread();
|
||||
|
||||
private:
|
||||
void FreeTmpBuffer() {
|
||||
|
|
|
@ -48,16 +48,9 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
|
|||
conv_param_->input_channel_ = in_channel;
|
||||
conv_param_->output_channel_ = out_channel;
|
||||
|
||||
int oc4 = UP_DIV(out_channel, C4NUM);
|
||||
#ifdef ENABLE_AVX
|
||||
const int oc_block = C16NUM;
|
||||
#else
|
||||
const int oc_block = C8NUM;
|
||||
#endif
|
||||
int oc_block_num = UP_DIV(out_channel, oc_block);
|
||||
|
||||
// set data
|
||||
auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float);
|
||||
auto trans_matrix_data_size =
|
||||
input_unit_ * input_unit_ * in_channel * UP_ROUND(out_channel, oc_block_) * sizeof(float);
|
||||
if (trans_weight_ == nullptr) {
|
||||
trans_weight_ = reinterpret_cast<float *>(malloc(trans_matrix_data_size));
|
||||
if (trans_weight_ == nullptr) {
|
||||
|
@ -83,14 +76,15 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
|
|||
MS_LOG(ERROR) << "get matrix g from CookToomFilter failed.";
|
||||
return ret;
|
||||
}
|
||||
ret = WinogradFilterTransform(origin_weight_, matrix_g, matrix_gt, oc_block);
|
||||
ret = WinogradFilterTransform(origin_weight_, matrix_g, matrix_gt, oc_block_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "winograd filter transform failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
// init bias
|
||||
size_t new_bias_size = oc4 * C4NUM * sizeof(float);
|
||||
size_t new_bias_size = UP_ROUND(out_channel, C4NUM) * sizeof(float);
|
||||
bias_data_ = malloc(new_bias_size);
|
||||
if (bias_data_ == nullptr) {
|
||||
bias_data_ = reinterpret_cast<float *>(malloc(new_bias_size));
|
||||
if (bias_data_ == nullptr) {
|
||||
|
@ -98,31 +92,30 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
|
|||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
}
|
||||
memset(bias_data_, 0, new_bias_size);
|
||||
if (in_tensors_.size() == kInputSize2) {
|
||||
memcpy(bias_data_, origin_bias_, out_channel * sizeof(float));
|
||||
size_t origin_size = out_channel * sizeof(float);
|
||||
memcpy(bias_data_, origin_bias_, origin_size);
|
||||
memset(reinterpret_cast<float *>(bias_data_) + out_channel, 0, new_bias_size - origin_size);
|
||||
} else {
|
||||
MS_ASSERT(in_tensors_.size() == kInputSize1);
|
||||
memset(bias_data_, 0, new_bias_size);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
|
||||
int channel_out = conv_param_->output_channel_;
|
||||
int oc8 = UP_DIV(channel_out, C8NUM);
|
||||
int tile_num = C12NUM;
|
||||
MS_ASSERT(ctx_->allocator != nullptr);
|
||||
|
||||
size_t tile_buffer_size =
|
||||
thread_count_ * tile_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float);
|
||||
thread_count_ * tile_num_ * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float);
|
||||
trans_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));
|
||||
if (trans_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc trans_input_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
|
||||
int oc8 = UP_ROUND(conv_param_->output_channel_, C8NUM);
|
||||
gemm_out_ = reinterpret_cast<float *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * tile_num * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float)));
|
||||
ctx_->allocator->Malloc(thread_count_ * tile_num_ * input_unit_ * input_unit_ * oc8 * sizeof(float)));
|
||||
if (gemm_out_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc gemm_out_ failed.";
|
||||
return RET_ERROR;
|
||||
|
@ -136,7 +129,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
|
|||
}
|
||||
|
||||
col_buffer_ = reinterpret_cast<float *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * tile_num * conv_param_->input_channel_ * sizeof(float)));
|
||||
ctx_->allocator->Malloc(thread_count_ * tile_num_ * conv_param_->input_channel_ * sizeof(float)));
|
||||
if (col_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc col_buffer_ failed.";
|
||||
return RET_ERROR;
|
||||
|
@ -164,10 +157,17 @@ int ConvolutionWinogradCPUKernel::ConfigInputOutput() {
|
|||
}
|
||||
|
||||
int ConvolutionWinogradCPUKernel::Init() {
|
||||
tile_num_ = C12NUM;
|
||||
#ifdef ENABLE_AVX
|
||||
oc_block_ = C16NUM;
|
||||
#else
|
||||
oc_block_ = C8NUM;
|
||||
#endif
|
||||
kernel_unit_ = conv_param_->kernel_h_;
|
||||
input_unit_ = output_unit_ + kernel_unit_ - 1;
|
||||
conv_param_->input_unit_ = input_unit_;
|
||||
conv_param_->output_unit_ = output_unit_;
|
||||
|
||||
auto ret = InitWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init weight bias failed.";
|
||||
|
@ -197,8 +197,8 @@ int ConvolutionWinogradCPUKernel::ReSize() {
|
|||
|
||||
int ConvolutionWinogradCPUKernel::RunImpl(int task_id) {
|
||||
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||
auto ori_input_data = reinterpret_cast<float *>(input_tensor->MutableData());
|
||||
auto output_data = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
|
||||
auto ori_input_data = reinterpret_cast<float *>(input_tensor->data_c());
|
||||
auto output_data = reinterpret_cast<float *>(out_tensors_.front()->data_c());
|
||||
ConvWinogardFp32(ori_input_data, trans_weight_, reinterpret_cast<const float *>(bias_data_), output_data,
|
||||
tmp_buffer_address_list_, task_id, conv_param_, in_func_, out_func_);
|
||||
return RET_OK;
|
||||
|
|
|
@ -70,9 +70,11 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
|
|||
col_buffer_ = nullptr;
|
||||
}
|
||||
}
|
||||
int kernel_unit_;
|
||||
int input_unit_;
|
||||
int kernel_unit_{0};
|
||||
int input_unit_{0};
|
||||
int output_unit_;
|
||||
int oc_block_{0};
|
||||
int tile_num_{0};
|
||||
float *origin_weight_; // do not free
|
||||
float *origin_bias_; // do not free
|
||||
float *tmp_data_ = nullptr;
|
||||
|
|
Loading…
Reference in New Issue