forked from mindspore-Ecosystem/mindspore
fix vec matmul fp32
This commit is contained in:
parent
7458b4a099
commit
7cc485f2de
|
@ -454,7 +454,7 @@ void VecMatmulFp16(const float16_t *a, const float16_t *b, float16_t *c, const f
|
|||
vst1q_f16(c, acc_0);
|
||||
|
||||
if (ci + C16NUM > col) {
|
||||
int c_remain = col - ci;
|
||||
int c_remain = col - ci - C8NUM;
|
||||
for (int i = 0; i < c_remain; ++i) {
|
||||
if (act_type == ActType_Relu) {
|
||||
c[C8NUM + i] = MSMAX(acc_1[i], (float16_t)0.0);
|
||||
|
|
|
@ -955,7 +955,7 @@ void MatVecMulFp32Neon64(const float *a, const float *b, float *c, const float *
|
|||
}
|
||||
vst1q_f32(c, acc_0);
|
||||
if (ci + C8NUM - 1 >= col) {
|
||||
int c_remain = col - ci;
|
||||
int c_remain = col - ci - C4NUM;
|
||||
for (int i = 0; i < c_remain; ++i) {
|
||||
if (act_type == ActType_Relu) {
|
||||
c[C4NUM + i] = MSMAX(acc_1[i], 0.0f);
|
||||
|
|
|
@ -54,7 +54,7 @@ void MatmulFp32BaseCPUKernel::ResizeParameter() {
|
|||
// vector matmul col is aligned to C8NUM in avx
|
||||
col_tile_ = C8NUM;
|
||||
#elif defined(ENABLE_ARM64)
|
||||
col_tile_ = 1;
|
||||
col_tile_ = C8NUM;
|
||||
#endif
|
||||
row_tile_ = 1;
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ void MatmulFp32BaseCPUKernel::ResizeParameter() {
|
|||
params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
|
||||
#elif defined(ENABLE_ARM64)
|
||||
// no matter vec_matmul_ or not, use col_tile_ to get col_align_
|
||||
params_->col_align_ = vec_matmul_ ? params_->col_ : UP_ROUND(params_->col_, col_tile_);
|
||||
params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
|
||||
#else
|
||||
params_->col_align_ = vec_matmul_ ? params_->col_ : UP_ROUND(params_->col_, col_tile_);
|
||||
#endif
|
||||
|
@ -176,7 +176,7 @@ int MatmulFp32BaseCPUKernel::InitMatrixB(const float *src_ptr) {
|
|||
#ifdef ENABLE_AVX
|
||||
RowMajor2Col32Major(src_data, dst, params_->deep_, params_->col_);
|
||||
#elif defined(ENABLE_ARM64)
|
||||
memcpy(dst, src_data, params_->col_ * params_->deep_ * sizeof(float));
|
||||
RowMajor2Col8Major(src_data, dst, params_->col_, params_->deep_);
|
||||
#else
|
||||
memcpy(dst, src_data, params_->col_ * params_->deep_ * sizeof(float));
|
||||
#endif
|
||||
|
@ -184,7 +184,7 @@ int MatmulFp32BaseCPUKernel::InitMatrixB(const float *src_ptr) {
|
|||
#ifdef ENABLE_AVX
|
||||
RowMajor2Row32Major(src_data, dst, params_->col_, params_->deep_);
|
||||
#elif defined(ENABLE_ARM64)
|
||||
RowMajor2ColMajor(src_data, dst, params_->deep_, params_->col_);
|
||||
RowMajor2Row8Major(src_data, dst, params_->deep_, params_->col_);
|
||||
#else
|
||||
RowMajor2ColMajor(src_data, dst, params_->deep_, params_->col_);
|
||||
#endif
|
||||
|
@ -258,7 +258,8 @@ int MatmulFp32BaseCPUKernel::FloatRun(int task_id) {
|
|||
#ifdef ENABLE_AVX
|
||||
MatVecMulAvxFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc, params_->col_align_);
|
||||
#elif defined(ENABLE_ARM64)
|
||||
MatVecMulFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
|
||||
MatVecMulFp32Neon64(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc,
|
||||
thread_stride_ * col_tile_);
|
||||
#else
|
||||
MatVecMulFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue