!19411 down fp32 vecmatmul

Merge pull request !19411 from zhaozhenlong/lite/issue/down_fp32_vecmatmul
This commit is contained in:
i-robot 2021-07-06 06:47:18 +00:00 committed by Gitee
commit faab555824
1 changed file with 5 additions and 5 deletions

View File

@@ -54,7 +54,7 @@ void MatmulFp32BaseCPUKernel::ResizeParameter() {
// vector matmul col is aligned to C8NUM in avx
col_tile_ = C8NUM;
#elif defined(ENABLE_ARM64)
-col_tile_ = C8NUM;
+col_tile_ = 1;
#endif
row_tile_ = 1;
}
@@ -64,7 +64,7 @@ void MatmulFp32BaseCPUKernel::ResizeParameter() {
params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
#elif defined(ENABLE_ARM64)
// no matter vec_matmul_ or not, use col_tile_ to get col_align_
-params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
+params_->col_align_ = vec_matmul_ ? params_->col_ : UP_ROUND(params_->col_, col_tile_);
#else
params_->col_align_ = vec_matmul_ ? params_->col_ : UP_ROUND(params_->col_, col_tile_);
#endif
@@ -176,7 +176,7 @@ int MatmulFp32BaseCPUKernel::InitMatrixB(const float *src_ptr) {
#ifdef ENABLE_AVX
RowMajor2Col32Major(src_data, dst, params_->deep_, params_->col_);
#elif defined(ENABLE_ARM64)
-RowMajor2Col8Major(src_data, dst, params_->col_, params_->deep_);
+memcpy(dst, src_data, params_->col_ * params_->deep_ * sizeof(float));
#else
memcpy(dst, src_data, params_->col_ * params_->deep_ * sizeof(float));
#endif
@@ -184,7 +184,7 @@ int MatmulFp32BaseCPUKernel::InitMatrixB(const float *src_ptr) {
#ifdef ENABLE_AVX
RowMajor2Row32Major(src_data, dst, params_->col_, params_->deep_);
#elif defined(ENABLE_ARM64)
-RowMajor2Row8Major(src_data, dst, params_->deep_, params_->col_);
+RowMajor2ColMajor(src_data, dst, params_->deep_, params_->col_);
#else
RowMajor2ColMajor(src_data, dst, params_->deep_, params_->col_);
#endif
@@ -258,7 +258,7 @@ int MatmulFp32BaseCPUKernel::FloatRun(int task_id) {
#ifdef ENABLE_AVX
MatVecMulAvxFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc, params_->col_align_);
#elif defined(ENABLE_ARM64)
-MatVecMulFp32Neon64(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc, params_->col_align_);
+MatVecMulFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
#else
MatVecMulFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
#endif