Fix vec matmul (fp32/fp16): correct the remainder-column count and use C8NUM column tiling on ARM64

This commit is contained in:
zhaozhenlong 2021-07-05 21:30:13 +08:00
parent 7458b4a099
commit 7cc485f2de
3 changed files with 8 additions and 7 deletions

View File

@ -454,7 +454,7 @@ void VecMatmulFp16(const float16_t *a, const float16_t *b, float16_t *c, const f
vst1q_f16(c, acc_0);
if (ci + C16NUM > col) {
int c_remain = col - ci;
int c_remain = col - ci - C8NUM;
for (int i = 0; i < c_remain; ++i) {
if (act_type == ActType_Relu) {
c[C8NUM + i] = MSMAX(acc_1[i], (float16_t)0.0);

View File

@ -955,7 +955,7 @@ void MatVecMulFp32Neon64(const float *a, const float *b, float *c, const float *
}
vst1q_f32(c, acc_0);
if (ci + C8NUM - 1 >= col) {
int c_remain = col - ci;
int c_remain = col - ci - C4NUM;
for (int i = 0; i < c_remain; ++i) {
if (act_type == ActType_Relu) {
c[C4NUM + i] = MSMAX(acc_1[i], 0.0f);

View File

@ -54,7 +54,7 @@ void MatmulFp32BaseCPUKernel::ResizeParameter() {
// For vector matmul, col is aligned to C8NUM on AVX.
col_tile_ = C8NUM;
#elif defined(ENABLE_ARM64)
col_tile_ = 1;
col_tile_ = C8NUM;
#endif
row_tile_ = 1;
}
@ -64,7 +64,7 @@ void MatmulFp32BaseCPUKernel::ResizeParameter() {
params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
#elif defined(ENABLE_ARM64)
// Regardless of whether vec_matmul_ is set, round col_ up to col_tile_ to get col_align_.
params_->col_align_ = vec_matmul_ ? params_->col_ : UP_ROUND(params_->col_, col_tile_);
params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
#else
params_->col_align_ = vec_matmul_ ? params_->col_ : UP_ROUND(params_->col_, col_tile_);
#endif
@ -176,7 +176,7 @@ int MatmulFp32BaseCPUKernel::InitMatrixB(const float *src_ptr) {
#ifdef ENABLE_AVX
RowMajor2Col32Major(src_data, dst, params_->deep_, params_->col_);
#elif defined(ENABLE_ARM64)
memcpy(dst, src_data, params_->col_ * params_->deep_ * sizeof(float));
RowMajor2Col8Major(src_data, dst, params_->col_, params_->deep_);
#else
memcpy(dst, src_data, params_->col_ * params_->deep_ * sizeof(float));
#endif
@ -184,7 +184,7 @@ int MatmulFp32BaseCPUKernel::InitMatrixB(const float *src_ptr) {
#ifdef ENABLE_AVX
RowMajor2Row32Major(src_data, dst, params_->col_, params_->deep_);
#elif defined(ENABLE_ARM64)
RowMajor2ColMajor(src_data, dst, params_->deep_, params_->col_);
RowMajor2Row8Major(src_data, dst, params_->deep_, params_->col_);
#else
RowMajor2ColMajor(src_data, dst, params_->deep_, params_->col_);
#endif
@ -258,7 +258,8 @@ int MatmulFp32BaseCPUKernel::FloatRun(int task_id) {
#ifdef ENABLE_AVX
MatVecMulAvxFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc, params_->col_align_);
#elif defined(ENABLE_ARM64)
MatVecMulFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
MatVecMulFp32Neon64(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc,
thread_stride_ * col_tile_);
#else
MatVecMulFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
#endif