optimize conv1x1

This commit is contained in:
sunsuodong 2021-06-20 16:29:16 +08:00
parent 35c1f14cf3
commit 71f58669bb
6 changed files with 1850 additions and 21 deletions

File diff suppressed because it is too large Load Diff

View File

@ -350,6 +350,29 @@ void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons
}
}
#ifdef ENABLE_DEBUG
void MatMul12x16Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, size_t stride, size_t out_type) {
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r12div = r / C12NUM, r12mod = r % C12NUM;
int c16div = c / C16NUM, c16mod = c % C16NUM;
size_t index = r * stride + c;
float16_t value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
size_t bi = c16div * deep * C16NUM + d * C16NUM + c16mod;
value = value + a[ai] * b[bi];
}
ADD_BIAS(value, bias, c)
DO_RELU(value, act_type)
DO_RELU6(value, act_type)
dst[index] = value;
}
}
}
#endif
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
int depth, int row, int col, int stride, int out_type) {
if (out_type == OutType_C8) {
@ -567,7 +590,9 @@ void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si
for (; ci < col8; ci += C8NUM) {
const float16_t *src_c = src_r + ci;
float16_t *dst_c = dst_r + ci * C12NUM;
#ifdef ENABLE_ARM82_A32
#ifdef ENABLE_ARM64
Transpose12x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), 24);
#elif ENABLE_ARM82_A32
Transpose12x8A32Fp16(src_c, dst_c, col * sizeof(float16_t), 24);
#else
for (int tr = 0; tr < C12NUM; tr++) {

View File

@ -45,6 +45,8 @@ void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons
int deep, int row, int col, int stride, int write_mode);
#ifdef ENABLE_ARM64
void MatMul12x16Fp16Opt(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, size_t stride, size_t out_type);
void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc);
@ -65,6 +67,9 @@ void MatVecMulA32NeonFp16(const float16_t *a, const float16_t *b, float16_t *c,
int depth, int col);
#endif
void MatMul12x16Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, size_t stride, size_t out_type);
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
int depth, int row, int col, int stride, int out_type);

View File

@ -640,6 +640,109 @@ inline void Transpose12x8A32Fp16(const float16_t *src_c, float16_t *dst_c, size_
#endif
#ifdef ENABLE_ARM64
void Transpose12x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) {
#ifdef ENABLE_DEBUG
for (int tr = 0; tr < C12NUM; tr++) {
for (int tc = 0; tc < C8NUM; tc++) {
dst_ptr[tc * C12NUM + tr] = src_ptr[tr * col + tc];
}
}
#else
asm volatile(
"mov x10, %[src_ptr]\n"
"mov x11, %[dst_ptr]\n"
"ld1 {v0.8h}, [x10], %[src_stride]\n"
"ld1 {v1.8h}, [x10], %[src_stride]\n"
"ld1 {v2.8h}, [x10], %[src_stride]\n"
"ld1 {v3.8h}, [x10], %[src_stride]\n"
"ld1 {v4.8h}, [x10], %[src_stride]\n"
"ld1 {v5.8h}, [x10], %[src_stride]\n"
"ld1 {v6.8h}, [x10], %[src_stride]\n"
"ld1 {v7.8h}, [x10], %[src_stride]\n"
"zip1 v16.8h, v0.8h, v1.8h\n"
"zip1 v17.8h, v2.8h, v3.8h\n"
"zip1 v18.8h, v4.8h, v5.8h\n"
"zip1 v19.8h, v6.8h, v7.8h\n"
"ld1 {v8.8h}, [x10], %[src_stride]\n"
"ld1 {v9.8h}, [x10], %[src_stride]\n"
"ld1 {v10.8h}, [x10], %[src_stride]\n"
"ld1 {v11.8h}, [x10], %[src_stride]\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v24.2d, v20.2d, v22.2d\n"
"trn2 v25.2d, v20.2d, v22.2d\n"
"trn1 v26.2d, v21.2d, v23.2d\n"
"trn2 v27.2d, v21.2d, v23.2d\n"
"zip1 v16.8h, v8.8h, v9.8h\n"
"zip1 v17.8h, v10.8h, v11.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v28.2d, v20.2d, v20.2d\n"
"trn2 v29.2d, v20.2d, v20.2d\n"
"trn1 v30.2d, v21.2d, v21.2d\n"
"trn2 v31.2d, v21.2d, v21.2d\n"
"add x10, x11, #16\n"
"st1 {v24.8h}, [x11], %[dst_stride]\n"
"st1 {v28.4h}, [x10], %[dst_stride]\n"
"st1 {v26.8h}, [x11], %[dst_stride]\n"
"st1 {v30.4h}, [x10], %[dst_stride]\n"
"st1 {v25.8h}, [x11], %[dst_stride]\n"
"st1 {v29.4h}, [x10], %[dst_stride]\n"
"st1 {v27.8h}, [x11], %[dst_stride]\n"
"st1 {v31.4h}, [x10], %[dst_stride]\n"
"zip2 v16.8h, v0.8h, v1.8h\n"
"zip2 v17.8h, v2.8h, v3.8h\n"
"zip2 v18.8h, v4.8h, v5.8h\n"
"zip2 v19.8h, v6.8h, v7.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v24.2d, v20.2d, v22.2d\n"
"trn2 v25.2d, v20.2d, v22.2d\n"
"trn1 v26.2d, v21.2d, v23.2d\n"
"trn2 v27.2d, v21.2d, v23.2d\n"
"zip2 v16.8h, v8.8h, v9.8h\n"
"zip2 v17.8h, v10.8h, v11.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v28.2d, v20.2d, v20.2d\n"
"trn2 v29.2d, v20.2d, v20.2d\n"
"trn1 v30.2d, v21.2d, v21.2d\n"
"trn2 v31.2d, v21.2d, v21.2d\n"
"st1 {v24.8h}, [x11], %[dst_stride]\n"
"st1 {v28.4h}, [x10], %[dst_stride]\n"
"st1 {v26.8h}, [x11], %[dst_stride]\n"
"st1 {v30.4h}, [x10], %[dst_stride]\n"
"st1 {v25.8h}, [x11], %[dst_stride]\n"
"st1 {v29.4h}, [x10], %[dst_stride]\n"
"st1 {v27.8h}, [x11], %[dst_stride]\n"
"st1 {v31.4h}, [x10], %[dst_stride]\n"
:
: [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride)
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
#endif
}
inline void Transpose16x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) {
asm volatile(
"mov x10, %[src_ptr]\n"

View File

@ -78,6 +78,7 @@ void Transpose12x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_strid
#endif
#ifdef ENABLE_ARM64
void Transpose12x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride);
void Transpose16x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
#endif

View File

@ -113,17 +113,22 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
}
void *weight_origin_tmp = IsTrainable() ? weight_tensor->data_c() : origin_weight_;
memset(reinterpret_cast<char *>(weight_ptr_) + down_size, 0, size - down_size);
#ifdef ENABLE_ARM64
RowMajor2Col16MajorFp16Opt(static_cast<const float16_t *>(weight_origin_tmp), weight_ptr_, output_channel,
input_channel);
#else
ColMajor2Row8MajorFp16(weight_origin_tmp, weight_ptr_, input_channel, output_channel, true);
#endif
return RET_OK;
}
int Convolution1x1FP16CPUKernel::Init() {
col_tile_ = C8NUM;
#ifdef ENABLE_ARM64
row_tile_ = C16NUM;
row_tile_ = C12NUM;
col_tile_ = C16NUM;
#else
row_tile_ = C12NUM;
col_tile_ = C8NUM;
#endif
matmul_param_ = new (std::nothrow) MatMulParameter();
if (matmul_param_ == nullptr) {
@ -174,11 +179,15 @@ int Convolution1x1FP16CPUKernel::RunOc(int task_id) {
}
auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float16_t *>(bias_data_) + thread_stride_ * task_id;
MatMulFp16(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
#ifdef ENABLE_ARM64
MatMul12x16Fp16Opt(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
matmul_param_->row_, cur_oc, matmul_param_->col_, OutType_Nhwc);
#else
MatMul12x8A32Fp16(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
matmul_param_->row_, cur_oc, matmul_param_->col_, OutType_Nhwc);
#endif
return RET_OK;
}
@ -191,16 +200,18 @@ int Convolution1x1FP16CPUKernel::RunHw(int task_id) {
float16_t *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_;
float16_t *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_;
#ifdef ENABLE_ARM64
RowMajor2Col16MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
#else
RowMajor2Col12MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
#endif
float16_t *thread_output_ptr = output_ptr_ + task_id * thread_stride_ * matmul_param_->col_;
MatMulFp16(thread_pack_input, weight_ptr_, thread_output_ptr, reinterpret_cast<float16_t *>(bias_data_),
#ifdef ENABLE_ARM64
MatMul12x16Fp16Opt(thread_pack_input, weight_ptr_, thread_output_ptr, reinterpret_cast<float16_t *>(bias_data_),
matmul_param_->act_type_, matmul_param_->deep_, cur_hw_, matmul_param_->col_, matmul_param_->col_,
OutType_Nhwc);
#else
MatMul12x8A32Fp16(thread_pack_input, weight_ptr_, thread_output_ptr, reinterpret_cast<float16_t *>(bias_data_),
matmul_param_->act_type_, matmul_param_->deep_, cur_hw_, matmul_param_->col_, matmul_param_->col_,
OutType_Nhwc);
#endif
return RET_OK;
}
@ -263,11 +274,7 @@ int Convolution1x1FP16CPUKernel::Run() {
if (multi_thread_by_hw_) {
ret = ParallelLaunch(this->context_, Convolution1x1Fp16RunHw, this, thread_count_);
} else {
#ifdef ENABLE_ARM64
RowMajor2Col16MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
#else
RowMajor2Col12MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
#endif
ret = ParallelLaunch(this->context_, Convolution1x1Fp16RunOc, this, thread_count_);
}
if (ret != RET_OK) {