!26445 [MS][LITE][CPU] thread split op

Merge pull request !26445 from liuzhongkai/matmul
This commit is contained in:
i-robot 2021-11-19 08:43:05 +00:00 committed by Gitee
commit 2f709641cd
2 changed files with 95 additions and 32 deletions

View File

@ -22,12 +22,12 @@
using mindspore::lite::RET_NULL_PTR;
namespace mindspore::kernel {
int MatmulBaseFloatRun(const void *cdata, int task_id, float, float) {
int MatmulRun(const void *cdata, int task_id, float, float) {
CHECK_NULL_RETURN(cdata);
auto op = reinterpret_cast<const MatmulFp32BaseCPUKernel *>(cdata);
auto error_code = op->FloatRun(task_id);
auto error_code = (op->*(op->parallel_fun_))(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "MatmulFp32Run error task_id[" << task_id << "] error_code[" << error_code << "]";
MS_LOG(ERROR) << "MatmulRun error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
@ -116,7 +116,7 @@ int MatmulFp32BaseCPUKernel::InitBiasData() {
MS_LOG(ERROR) << "bias_tensor invalid";
return RET_ERROR;
}
size_t bias_num = static_cast<size_t>(bias_tensor->ElementsNum());
auto bias_num = static_cast<size_t>(bias_tensor->ElementsNum());
MS_CHECK_TRUE_RET(bias_num > 0, RET_ERROR);
if (bias_num == 1) {
// broadcast bias data
@ -134,7 +134,7 @@ int MatmulFp32BaseCPUKernel::InitBiasData() {
return RET_OK;
}
size_t max_bias_data = static_cast<size_t>(UP_ROUND(bias_num, col_tile_));
auto max_bias_data = static_cast<size_t>(UP_ROUND(bias_num, col_tile_));
// malloc addr need to aligned to 32 bytes
bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * static_cast<int>(sizeof(float))));
if (bias_ptr_ == nullptr) {
@ -206,15 +206,52 @@ void MatmulFp32BaseCPUKernel::FreeResizeBufB() {
}
}
int MatmulFp32BaseCPUKernel::FloatRun(int task_id) const {
int current_start_oc = task_id * thread_stride_ * col_tile_;
int current_rest_oc = 0;
#if defined(ENABLE_AVX)
current_rest_oc = params_->col_align_ - current_start_oc;
int MatmulFp32BaseCPUKernel::ParallelRunByBatch(int task_id) const {
int start_batch = task_id * batch_stride_;
int end_batch = MSMIN(params_->batch, start_batch + batch_stride_);
#ifdef ENABLE_AVX
int col_step = params_->col_align_;
#else
current_rest_oc = params_->col_ - current_start_oc;
// col need not aligned
int col_step = params_->col_;
#endif
int cur_oc = MSMIN(thread_stride_ * col_tile_, current_rest_oc);
for (int index = start_batch; index < end_batch; ++index) {
const float *a = a_pack_ptr_ + a_offset_[index] * params_->row_align_ * params_->deep_;
const float *b = b_pack_ptr_ + b_offset_[index] * params_->deep_ * params_->col_align_;
float *c = output_data_ + index * params_->row_ * col_step;
auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_;
if (vec_matmul_) {
#ifdef ENABLE_AVX
MatVecMulAvxFp32(a, b, c, bias, params_->act_type_, params_->deep_, col_step, params_->col_align_);
#elif defined(ENABLE_ARM64)
MatVecMulFp32Neon64(a, b, c, bias, params_->act_type_, params_->deep_, col_step, params_->col_align_);
#elif defined(ENABLE_ARM32)
MatVecMulFp32Block4(a, b, c, bias, params_->act_type_, params_->deep_, col_step);
#else
MatVecMulFp32Block8(a, b, c, bias, params_->act_type_, params_->deep_, col_step);
#endif
} else {
#ifdef ENABLE_AVX
MatMulAvxFp32(a, b, c, bias, params_->act_type_, params_->deep_, col_step, params_->col_align_, params_->row_);
#else
MatMulOpt(a, b, c, bias, params_->act_type_, params_->deep_, params_->row_, col_step, params_->col_,
OutType_Nhwc);
#endif
}
}
return RET_OK;
}
int MatmulFp32BaseCPUKernel::ParallelRunByOC(int task_id) const {
int current_start_oc = task_id * oc_stride_ * col_tile_;
#if defined(ENABLE_AVX)
int current_rest_oc = params_->col_align_ - current_start_oc;
#else
int current_rest_oc = params_->col_ - current_start_oc;
#endif
int cur_oc = MSMIN(oc_stride_ * col_tile_, current_rest_oc);
if (cur_oc <= 0) {
return RET_OK;
}
@ -226,7 +263,7 @@ int MatmulFp32BaseCPUKernel::FloatRun(int task_id) const {
#ifdef ENABLE_AVX
MatVecMulAvxFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc, params_->col_align_);
#elif defined(ENABLE_ARM64)
int rest_align_col = MSMIN(params_->col_align_ - current_start_oc, thread_stride_ * col_tile_);
int rest_align_col = MSMIN(params_->col_align_ - current_start_oc, oc_stride_ * col_tile_);
MatVecMulFp32Neon64(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc, rest_align_col);
#elif defined(ENABLE_ARM32)
MatVecMulFp32Block4(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
@ -323,13 +360,7 @@ int MatmulFp32BaseCPUKernel::ReSize() {
if (op_parameter_->is_train_session_) {
set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * static_cast<int>(sizeof(float)));
}
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_align_, col_tile_));
#if defined(ENABLE_AVX) // thread tile by col_tile * C4NUM
thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_ * C4NUM), thread_count_) * C4NUM;
#else
thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_);
#endif
GetThreadCuttingPolicy();
return RET_OK;
}
@ -368,6 +399,24 @@ int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
return RET_OK;
}
// Decide how the matmul workload is divided among threads.  With at least
// as many batches as configured threads, each thread handles a contiguous
// slice of whole batches; otherwise all threads cooperate on each matmul by
// splitting the output-channel (OC) dimension.
void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
  const bool split_by_batch = params_->batch >= op_parameter_->thread_num_;
  batch_split_ = split_by_batch;
  if (!split_by_batch) {
    // Too few batches: cut along the output-channel axis instead.
    thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_align_, col_tile_));
#if defined(ENABLE_AVX)  // thread tile by col_tile * C4NUM
    oc_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_ * C4NUM), thread_count_) * C4NUM;
#else
    oc_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_);
#endif
    parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByOC;
    return;
  }
  // Enough batches: every thread processes batch_stride_ consecutive batches.
  thread_count_ = op_parameter_->thread_num_;
  batch_stride_ = UP_DIV(params_->batch, thread_count_);
  parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByBatch;
}
int MatmulFp32BaseCPUKernel::Run() {
if (!params_->a_const_) {
auto a_ptr = reinterpret_cast<float *>(in_tensors_[0]->data());
@ -403,20 +452,28 @@ int MatmulFp32BaseCPUKernel::Run() {
return ret;
}
for (int i = 0; i < params_->batch; ++i) {
batch_a_ptr_ = a_pack_ptr_ + a_offset_[i] * params_->row_align_ * params_->deep_;
batch_b_ptr_ = b_pack_ptr_ + b_offset_[i] * params_->deep_ * params_->col_align_;
if (batch_split_) {
ret = ParallelLaunch(this->ms_context_, MatmulRun, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "MatmulRun failed in split by batch";
return ret;
}
} else {
#ifdef ENABLE_AVX
batch_c_ptr_ = output_data_ + i * params_->row_ * params_->col_align_;
int col_step = params_->col_align_;
#else
// need not aligned
batch_c_ptr_ = output_data_ + i * params_->row_ * params_->col_;
int col_step = params_->col_;
#endif
ret = ParallelLaunch(this->ms_context_, MatmulBaseFloatRun, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "MatmulBaseFloatRun failed";
return ret;
for (int i = 0; i < params_->batch; ++i) {
batch_a_ptr_ = a_pack_ptr_ + a_offset_[i] * params_->row_align_ * params_->deep_;
batch_b_ptr_ = b_pack_ptr_ + b_offset_[i] * params_->deep_ * params_->col_align_;
batch_c_ptr_ = output_data_ + i * params_->row_ * col_step;
ret = ParallelLaunch(this->ms_context_, MatmulRun, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "MatmulRun failed in split by oc";
return ret;
}
}
}

View File

@ -43,7 +43,10 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
int Run() override;
public:
int FloatRun(int task_id) const;
int ParallelRunByOC(int task_id) const;
int ParallelRunByBatch(int task_id) const;
using ParallelRun = int (MatmulFp32BaseCPUKernel::*)(int task_id) const;
ParallelRun parallel_fun_ = nullptr;
protected:
int InitBufferA();
@ -61,6 +64,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
void FreeResizeBufB();
int CalBroadCastBiasDataElements();
int InitTmpOutBuffer();
void GetThreadCuttingPolicy();
protected:
MatMulParameter *params_ = nullptr;
@ -75,7 +79,8 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
int col_tile_ = 0;
int row_tile_ = 0;
int oc_res_ = 0;
int thread_stride_ = 0;
int batch_stride_ = 0;
int oc_stride_ = 0;
int thread_count_ = 0;
bool vec_matmul_ = false;
float *bias_ptr_ = nullptr;
@ -87,6 +92,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
int matrix_b_pack_size_ = -1;
MatrixPackFun matrix_a_pack_fun_ = nullptr;
MatrixPackFun matrix_b_pack_fun_ = nullptr;
bool batch_split_ = false;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_FP32_BASE_H_