diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc
index 8506f05eb7c..8f131e80bb4 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc
@@ -22,12 +22,12 @@ using mindspore::lite::RET_NULL_PTR;
 
 namespace mindspore::kernel {
-int MatmulBaseFloatRun(const void *cdata, int task_id, float, float) {
+int MatmulRun(const void *cdata, int task_id, float, float) {
   CHECK_NULL_RETURN(cdata);
   auto op = reinterpret_cast<const MatmulFp32BaseCPUKernel *>(cdata);
-  auto error_code = op->FloatRun(task_id);
+  auto error_code = (op->*(op->parallel_fun_))(task_id);
   if (error_code != RET_OK) {
-    MS_LOG(ERROR) << "MatmulFp32Run error task_id[" << task_id << "] error_code[" << error_code << "]";
+    MS_LOG(ERROR) << "MatmulRun error task_id[" << task_id << "] error_code[" << error_code << "]";
     return RET_ERROR;
   }
   return RET_OK;
 }
@@ -116,7 +116,7 @@ int MatmulFp32BaseCPUKernel::InitBiasData() {
     MS_LOG(ERROR) << "bias_tensor invalid";
     return RET_ERROR;
   }
-  size_t bias_num = static_cast<size_t>(bias_tensor->ElementsNum());
+  auto bias_num = static_cast<size_t>(bias_tensor->ElementsNum());
   MS_CHECK_TRUE_RET(bias_num > 0, RET_ERROR);
   if (bias_num == 1) {
     // broadcast bias data
@@ -134,7 +134,7 @@ int MatmulFp32BaseCPUKernel::InitBiasData() {
     return RET_OK;
   }
 
-  size_t max_bias_data = static_cast<size_t>(UP_ROUND(bias_num, col_tile_));
+  auto max_bias_data = static_cast<size_t>(UP_ROUND(bias_num, col_tile_));
   // malloc addr need to aligned to 32 bytes
   bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * static_cast<size_t>(sizeof(float))));
   if (bias_ptr_ == nullptr) {
@@ -206,15 +206,52 @@ void MatmulFp32BaseCPUKernel::FreeResizeBufB() {
   }
 }
 
-int MatmulFp32BaseCPUKernel::FloatRun(int task_id) const {
-  int current_start_oc = task_id * thread_stride_ * col_tile_;
-  int current_rest_oc = 0;
-#if defined(ENABLE_AVX)
-  current_rest_oc = params_->col_align_ - current_start_oc;
+int MatmulFp32BaseCPUKernel::ParallelRunByBatch(int task_id) const {
+  int start_batch = task_id * batch_stride_;
+  int end_batch = MSMIN(params_->batch, start_batch + batch_stride_);
+#ifdef ENABLE_AVX
+  int col_step = params_->col_align_;
 #else
-  current_rest_oc = params_->col_ - current_start_oc;
+  // col need not aligned
+  int col_step = params_->col_;
 #endif
-  int cur_oc = MSMIN(thread_stride_ * col_tile_, current_rest_oc);
+
+  for (int index = start_batch; index < end_batch; ++index) {
+    const float *a = a_pack_ptr_ + a_offset_[index] * params_->row_align_ * params_->deep_;
+    const float *b = b_pack_ptr_ + b_offset_[index] * params_->deep_ * params_->col_align_;
+    float *c = output_data_ + index * params_->row_ * col_step;
+
+    auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_;
+    if (vec_matmul_) {
+#ifdef ENABLE_AVX
+      MatVecMulAvxFp32(a, b, c, bias, params_->act_type_, params_->deep_, col_step, params_->col_align_);
+#elif defined(ENABLE_ARM64)
+      MatVecMulFp32Neon64(a, b, c, bias, params_->act_type_, params_->deep_, col_step, params_->col_align_);
+#elif defined(ENABLE_ARM32)
+      MatVecMulFp32Block4(a, b, c, bias, params_->act_type_, params_->deep_, col_step);
+#else
+      MatVecMulFp32Block8(a, b, c, bias, params_->act_type_, params_->deep_, col_step);
+#endif
+    } else {
+#ifdef ENABLE_AVX
+      MatMulAvxFp32(a, b, c, bias, params_->act_type_, params_->deep_, col_step, params_->col_align_, params_->row_);
+#else
+      MatMulOpt(a, b, c, bias, params_->act_type_, params_->deep_, params_->row_, col_step, params_->col_,
+                OutType_Nhwc);
+#endif
+    }
+  }
+  return RET_OK;
+}
+
+int MatmulFp32BaseCPUKernel::ParallelRunByOC(int task_id) const {
+  int current_start_oc = task_id * oc_stride_ * col_tile_;
+#if defined(ENABLE_AVX)
+  int current_rest_oc = params_->col_align_ - current_start_oc;
+#else
+  int current_rest_oc = params_->col_ - current_start_oc;
+#endif
+  int cur_oc = MSMIN(oc_stride_ * col_tile_, current_rest_oc);
   if (cur_oc <= 0) {
     return RET_OK;
   }
@@ -226,7 +263,7 @@ int MatmulFp32BaseCPUKernel::FloatRun(int task_id) const {
 #ifdef ENABLE_AVX
     MatVecMulAvxFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc, params_->col_align_);
 #elif defined(ENABLE_ARM64)
-    int rest_align_col = MSMIN(params_->col_align_ - current_start_oc, thread_stride_ * col_tile_);
+    int rest_align_col = MSMIN(params_->col_align_ - current_start_oc, oc_stride_ * col_tile_);
     MatVecMulFp32Neon64(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc, rest_align_col);
 #elif defined(ENABLE_ARM32)
     MatVecMulFp32Block4(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
@@ -323,13 +360,7 @@ int MatmulFp32BaseCPUKernel::ReSize() {
 
   if (op_parameter_->is_train_session_) {
     set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * static_cast<size_t>(sizeof(float)));
   }
-  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_align_, col_tile_));
-#if defined(ENABLE_AVX) // thread tile by col_tile * C4NUM
-  thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_ * C4NUM), thread_count_) * C4NUM;
-#else
-  thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_);
-#endif
-
+  GetThreadCuttingPolicy();
   return RET_OK;
 }
@@ -368,6 +399,24 @@ int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
   return RET_OK;
 }
 
+void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
+  if (params_->batch >= op_parameter_->thread_num_) {
+    thread_count_ = op_parameter_->thread_num_;
+    batch_stride_ = UP_DIV(params_->batch, thread_count_);
+    batch_split_ = true;
+    parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByBatch;
+  } else {
+    thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_align_, col_tile_));
+#if defined(ENABLE_AVX) // thread tile by col_tile * C4NUM
+    oc_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_ * C4NUM), thread_count_) * C4NUM;
+#else
+    oc_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_);
+#endif
+    batch_split_ = false;
+    parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByOC;
+  }
+}
+
 int MatmulFp32BaseCPUKernel::Run() {
   if (!params_->a_const_) {
     auto a_ptr = reinterpret_cast<float *>(in_tensors_[0]->data());
@@ -403,20 +452,28 @@ int MatmulFp32BaseCPUKernel::Run() {
     return ret;
   }
 
-  for (int i = 0; i < params_->batch; ++i) {
-    batch_a_ptr_ = a_pack_ptr_ + a_offset_[i] * params_->row_align_ * params_->deep_;
-    batch_b_ptr_ = b_pack_ptr_ + b_offset_[i] * params_->deep_ * params_->col_align_;
+  if (batch_split_) {
+    ret = ParallelLaunch(this->ms_context_, MatmulRun, this, thread_count_);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "MatmulRun failed in split by batch";
+      return ret;
+    }
+  } else {
 #ifdef ENABLE_AVX
-    batch_c_ptr_ = output_data_ + i * params_->row_ * params_->col_align_;
+    int col_step = params_->col_align_;
 #else
     // need not aligned
-    batch_c_ptr_ = output_data_ + i * params_->row_ * params_->col_;
+    int col_step = params_->col_;
 #endif
-
-    ret = ParallelLaunch(this->ms_context_, MatmulBaseFloatRun, this, thread_count_);
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "MatmulBaseFloatRun failed";
-      return ret;
+    for (int i = 0; i < params_->batch; ++i) {
+      batch_a_ptr_ = a_pack_ptr_ + a_offset_[i] * params_->row_align_ * params_->deep_;
+      batch_b_ptr_ = b_pack_ptr_ + b_offset_[i] * params_->deep_ * params_->col_align_;
+      batch_c_ptr_ = output_data_ + i * params_->row_ * col_step;
+      ret = ParallelLaunch(this->ms_context_, MatmulRun, this, thread_count_);
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "MatmulRun failed in split by oc";
+        return ret;
+      }
     }
   }
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h
index ed45554e999..7da9ad35505 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h
@@ -43,7 +43,10 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
   int Run() override;
 
  public:
-  int FloatRun(int task_id) const;
+  int ParallelRunByOC(int task_id) const;
+  int ParallelRunByBatch(int task_id) const;
+  using ParallelRun = int (MatmulFp32BaseCPUKernel::*)(int task_id) const;
+  ParallelRun parallel_fun_ = nullptr;
 
  protected:
   int InitBufferA();
@@ -61,6 +64,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
   void FreeResizeBufB();
   int CalBroadCastBiasDataElements();
   int InitTmpOutBuffer();
+  void GetThreadCuttingPolicy();
 
  protected:
   MatMulParameter *params_ = nullptr;
@@ -75,7 +79,8 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
   int col_tile_ = 0;
   int row_tile_ = 0;
   int oc_res_ = 0;
-  int thread_stride_ = 0;
+  int batch_stride_ = 0;
+  int oc_stride_ = 0;
   int thread_count_ = 0;
   bool vec_matmul_ = false;
   float *bias_ptr_ = nullptr;
@@ -87,6 +92,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
   int matrix_b_pack_size_ = -1;
   MatrixPackFun matrix_a_pack_fun_ = nullptr;
   MatrixPackFun matrix_b_pack_fun_ = nullptr;
+  bool batch_split_ = false;
 };
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_FP32_BASE_H_
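
Note on the dispatch introduced above: `parallel_fun_` is a pointer to a const member function, and `MatmulRun` invokes it with `(op->*(op->parallel_fun_))(task_id)`, so one `ParallelLaunch` callback serves both cutting strategies and the strategy is chosen once in `GetThreadCuttingPolicy` rather than per task. The following is a minimal standalone sketch of that pattern; the `Kernel`, `RunByBatch`, `RunByOC`, and `Run` names are illustrative stand-ins, not MindSpore APIs.

    #include <iostream>

    // Hypothetical stand-in for MatmulFp32BaseCPUKernel: a thread-cutting policy
    // stores a pointer-to-member, and the runner calls through it, mirroring
    // (op->*(op->parallel_fun_))(task_id) in the patch above.
    class Kernel {
     public:
      using ParallelRun = int (Kernel::*)(int task_id) const;

      void GetThreadCuttingPolicy(int batch, int thread_num) {
        // Prefer splitting by batch when every thread can get at least one batch.
        parallel_fun_ = (batch >= thread_num) ? &Kernel::RunByBatch : &Kernel::RunByOC;
      }

      int RunByBatch(int task_id) const {
        std::cout << "batch task " << task_id << "\n";
        return 0;
      }
      int RunByOC(int task_id) const {
        std::cout << "oc task " << task_id << "\n";
        return 0;
      }

      ParallelRun parallel_fun_ = nullptr;
    };

    // Callback in the style of MatmulRun: cast the opaque pointer back to the
    // kernel and invoke whichever member function the policy selected.
    int Run(const void *cdata, int task_id) {
      auto op = static_cast<const Kernel *>(cdata);
      return (op->*(op->parallel_fun_))(task_id);
    }

    int main() {
      Kernel k;
      k.GetThreadCuttingPolicy(/*batch=*/4, /*thread_num=*/2);
      for (int t = 0; t < 2; ++t) {
        Run(&k, t);  // dispatches to RunByBatch
      }
      return 0;
    }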