!18601 [MS][LITE][Develop] optimize conv_1x1 kernel
Merge pull request !18601 from sunsuodong/optimize_lite_kernel
This commit is contained in:
commit
8644962596
File diff suppressed because it is too large
Load Diff
|
@ -350,6 +350,29 @@ void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef ENABLE_DEBUG
|
||||
void MatMul12x16Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
|
||||
int deep, int row, int col, size_t stride, size_t out_type) {
|
||||
for (int r = 0; r < row; r++) {
|
||||
for (int c = 0; c < col; c++) {
|
||||
int r12div = r / C12NUM, r12mod = r % C12NUM;
|
||||
int c16div = c / C16NUM, c16mod = c % C16NUM;
|
||||
size_t index = r * stride + c;
|
||||
float16_t value = 0;
|
||||
for (int d = 0; d < deep; d++) {
|
||||
size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
|
||||
size_t bi = c16div * deep * C16NUM + d * C16NUM + c16mod;
|
||||
value = value + a[ai] * b[bi];
|
||||
}
|
||||
ADD_BIAS(value, bias, c)
|
||||
DO_RELU(value, act_type)
|
||||
DO_RELU6(value, act_type)
|
||||
dst[index] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
|
||||
int depth, int row, int col, int stride, int out_type) {
|
||||
if (out_type == OutType_C8) {
|
||||
|
@ -567,7 +590,9 @@ void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si
|
|||
for (; ci < col8; ci += C8NUM) {
|
||||
const float16_t *src_c = src_r + ci;
|
||||
float16_t *dst_c = dst_r + ci * C12NUM;
|
||||
#ifdef ENABLE_ARM82_A32
|
||||
#ifdef ENABLE_ARM64
|
||||
Transpose12x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), 24);
|
||||
#elif ENABLE_ARM82_A32
|
||||
Transpose12x8A32Fp16(src_c, dst_c, col * sizeof(float16_t), 24);
|
||||
#else
|
||||
for (int tr = 0; tr < C12NUM; tr++) {
|
||||
|
|
|
@ -45,6 +45,8 @@ void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons
|
|||
int deep, int row, int col, int stride, int write_mode);
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void MatMul12x16Fp16Opt(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
|
||||
int deep, int row, int col, size_t stride, size_t out_type);
|
||||
void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc);
|
||||
|
||||
|
@ -65,6 +67,9 @@ void MatVecMulA32NeonFp16(const float16_t *a, const float16_t *b, float16_t *c,
|
|||
int depth, int col);
|
||||
#endif
|
||||
|
||||
void MatMul12x16Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
|
||||
int deep, int row, int col, size_t stride, size_t out_type);
|
||||
|
||||
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
|
||||
int depth, int row, int col, int stride, int out_type);
|
||||
|
||||
|
|
|
@ -640,6 +640,109 @@ inline void Transpose12x8A32Fp16(const float16_t *src_c, float16_t *dst_c, size_
|
|||
#endif
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void Transpose12x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) {
|
||||
#ifdef ENABLE_DEBUG
|
||||
for (int tr = 0; tr < C12NUM; tr++) {
|
||||
for (int tc = 0; tc < C8NUM; tc++) {
|
||||
dst_ptr[tc * C12NUM + tr] = src_ptr[tr * col + tc];
|
||||
}
|
||||
}
|
||||
#else
|
||||
asm volatile(
|
||||
"mov x10, %[src_ptr]\n"
|
||||
"mov x11, %[dst_ptr]\n"
|
||||
|
||||
"ld1 {v0.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v1.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v2.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v3.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v4.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v5.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v6.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v7.8h}, [x10], %[src_stride]\n"
|
||||
|
||||
"zip1 v16.8h, v0.8h, v1.8h\n"
|
||||
"zip1 v17.8h, v2.8h, v3.8h\n"
|
||||
"zip1 v18.8h, v4.8h, v5.8h\n"
|
||||
"zip1 v19.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"ld1 {v8.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v9.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v10.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v11.8h}, [x10], %[src_stride]\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v24.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v25.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v26.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v27.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"zip1 v16.8h, v8.8h, v9.8h\n"
|
||||
"zip1 v17.8h, v10.8h, v11.8h\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
|
||||
"trn1 v28.2d, v20.2d, v20.2d\n"
|
||||
"trn2 v29.2d, v20.2d, v20.2d\n"
|
||||
"trn1 v30.2d, v21.2d, v21.2d\n"
|
||||
"trn2 v31.2d, v21.2d, v21.2d\n"
|
||||
|
||||
"add x10, x11, #16\n"
|
||||
"st1 {v24.8h}, [x11], %[dst_stride]\n"
|
||||
"st1 {v28.4h}, [x10], %[dst_stride]\n"
|
||||
"st1 {v26.8h}, [x11], %[dst_stride]\n"
|
||||
"st1 {v30.4h}, [x10], %[dst_stride]\n"
|
||||
"st1 {v25.8h}, [x11], %[dst_stride]\n"
|
||||
"st1 {v29.4h}, [x10], %[dst_stride]\n"
|
||||
"st1 {v27.8h}, [x11], %[dst_stride]\n"
|
||||
"st1 {v31.4h}, [x10], %[dst_stride]\n"
|
||||
|
||||
"zip2 v16.8h, v0.8h, v1.8h\n"
|
||||
"zip2 v17.8h, v2.8h, v3.8h\n"
|
||||
"zip2 v18.8h, v4.8h, v5.8h\n"
|
||||
"zip2 v19.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v24.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v25.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v26.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v27.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"zip2 v16.8h, v8.8h, v9.8h\n"
|
||||
"zip2 v17.8h, v10.8h, v11.8h\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
|
||||
"trn1 v28.2d, v20.2d, v20.2d\n"
|
||||
"trn2 v29.2d, v20.2d, v20.2d\n"
|
||||
"trn1 v30.2d, v21.2d, v21.2d\n"
|
||||
"trn2 v31.2d, v21.2d, v21.2d\n"
|
||||
|
||||
"st1 {v24.8h}, [x11], %[dst_stride]\n"
|
||||
"st1 {v28.4h}, [x10], %[dst_stride]\n"
|
||||
"st1 {v26.8h}, [x11], %[dst_stride]\n"
|
||||
"st1 {v30.4h}, [x10], %[dst_stride]\n"
|
||||
"st1 {v25.8h}, [x11], %[dst_stride]\n"
|
||||
"st1 {v29.4h}, [x10], %[dst_stride]\n"
|
||||
"st1 {v27.8h}, [x11], %[dst_stride]\n"
|
||||
"st1 {v31.4h}, [x10], %[dst_stride]\n"
|
||||
:
|
||||
: [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride)
|
||||
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18",
|
||||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void Transpose16x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) {
|
||||
asm volatile(
|
||||
"mov x10, %[src_ptr]\n"
|
||||
|
|
|
@ -78,6 +78,7 @@ void Transpose12x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_strid
|
|||
#endif
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void Transpose12x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride);
|
||||
void Transpose16x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
|
||||
#endif
|
||||
|
||||
|
|
|
@ -113,17 +113,22 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
|
|||
}
|
||||
void *weight_origin_tmp = IsTrainable() ? weight_tensor->data_c() : origin_weight_;
|
||||
memset(reinterpret_cast<char *>(weight_ptr_) + down_size, 0, size - down_size);
|
||||
#ifdef ENABLE_ARM64
|
||||
RowMajor2Col16MajorFp16Opt(static_cast<const float16_t *>(weight_origin_tmp), weight_ptr_, output_channel,
|
||||
input_channel);
|
||||
#else
|
||||
ColMajor2Row8MajorFp16(weight_origin_tmp, weight_ptr_, input_channel, output_channel, true);
|
||||
|
||||
#endif
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int Convolution1x1FP16CPUKernel::Init() {
|
||||
col_tile_ = C8NUM;
|
||||
#ifdef ENABLE_ARM64
|
||||
row_tile_ = C16NUM;
|
||||
row_tile_ = C12NUM;
|
||||
col_tile_ = C16NUM;
|
||||
#else
|
||||
row_tile_ = C12NUM;
|
||||
col_tile_ = C8NUM;
|
||||
#endif
|
||||
matmul_param_ = new (std::nothrow) MatMulParameter();
|
||||
if (matmul_param_ == nullptr) {
|
||||
|
@ -174,11 +179,15 @@ int Convolution1x1FP16CPUKernel::RunOc(int task_id) {
|
|||
}
|
||||
|
||||
auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float16_t *>(bias_data_) + thread_stride_ * task_id;
|
||||
|
||||
MatMulFp16(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
|
||||
#ifdef ENABLE_ARM64
|
||||
MatMul12x16Fp16Opt(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
|
||||
output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
|
||||
matmul_param_->row_, cur_oc, matmul_param_->col_, OutType_Nhwc);
|
||||
|
||||
#else
|
||||
MatMul12x8A32Fp16(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
|
||||
output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
|
||||
matmul_param_->row_, cur_oc, matmul_param_->col_, OutType_Nhwc);
|
||||
#endif
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -191,16 +200,18 @@ int Convolution1x1FP16CPUKernel::RunHw(int task_id) {
|
|||
|
||||
float16_t *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_;
|
||||
float16_t *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_;
|
||||
#ifdef ENABLE_ARM64
|
||||
RowMajor2Col16MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
|
||||
#else
|
||||
RowMajor2Col12MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
|
||||
#endif
|
||||
|
||||
float16_t *thread_output_ptr = output_ptr_ + task_id * thread_stride_ * matmul_param_->col_;
|
||||
MatMulFp16(thread_pack_input, weight_ptr_, thread_output_ptr, reinterpret_cast<float16_t *>(bias_data_),
|
||||
#ifdef ENABLE_ARM64
|
||||
MatMul12x16Fp16Opt(thread_pack_input, weight_ptr_, thread_output_ptr, reinterpret_cast<float16_t *>(bias_data_),
|
||||
matmul_param_->act_type_, matmul_param_->deep_, cur_hw_, matmul_param_->col_, matmul_param_->col_,
|
||||
OutType_Nhwc);
|
||||
|
||||
#else
|
||||
MatMul12x8A32Fp16(thread_pack_input, weight_ptr_, thread_output_ptr, reinterpret_cast<float16_t *>(bias_data_),
|
||||
matmul_param_->act_type_, matmul_param_->deep_, cur_hw_, matmul_param_->col_, matmul_param_->col_,
|
||||
OutType_Nhwc);
|
||||
#endif
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -263,11 +274,7 @@ int Convolution1x1FP16CPUKernel::Run() {
|
|||
if (multi_thread_by_hw_) {
|
||||
ret = ParallelLaunch(this->context_, Convolution1x1Fp16RunHw, this, thread_count_);
|
||||
} else {
|
||||
#ifdef ENABLE_ARM64
|
||||
RowMajor2Col16MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
|
||||
#else
|
||||
RowMajor2Col12MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
|
||||
#endif
|
||||
ret = ParallelLaunch(this->context_, Convolution1x1Fp16RunOc, this, thread_count_);
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
|
|
Loading…
Reference in New Issue