forked from mindspore-Ecosystem/mindspore
!26445 [MS][LITE][CPU] thread split op
Merge pull request !26445 from liuzhongkai/matmul
This commit is contained in:
commit
2f709641cd
|
@ -22,12 +22,12 @@
|
|||
using mindspore::lite::RET_NULL_PTR;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int MatmulBaseFloatRun(const void *cdata, int task_id, float, float) {
|
||||
int MatmulRun(const void *cdata, int task_id, float, float) {
|
||||
CHECK_NULL_RETURN(cdata);
|
||||
auto op = reinterpret_cast<const MatmulFp32BaseCPUKernel *>(cdata);
|
||||
auto error_code = op->FloatRun(task_id);
|
||||
auto error_code = (op->*(op->parallel_fun_))(task_id);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "MatmulFp32Run error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||
MS_LOG(ERROR) << "MatmulRun error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -116,7 +116,7 @@ int MatmulFp32BaseCPUKernel::InitBiasData() {
|
|||
MS_LOG(ERROR) << "bias_tensor invalid";
|
||||
return RET_ERROR;
|
||||
}
|
||||
size_t bias_num = static_cast<size_t>(bias_tensor->ElementsNum());
|
||||
auto bias_num = static_cast<size_t>(bias_tensor->ElementsNum());
|
||||
MS_CHECK_TRUE_RET(bias_num > 0, RET_ERROR);
|
||||
if (bias_num == 1) {
|
||||
// broadcast bias data
|
||||
|
@ -134,7 +134,7 @@ int MatmulFp32BaseCPUKernel::InitBiasData() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
size_t max_bias_data = static_cast<size_t>(UP_ROUND(bias_num, col_tile_));
|
||||
auto max_bias_data = static_cast<size_t>(UP_ROUND(bias_num, col_tile_));
|
||||
// malloc addr need to aligned to 32 bytes
|
||||
bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * static_cast<int>(sizeof(float))));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
|
@ -206,15 +206,52 @@ void MatmulFp32BaseCPUKernel::FreeResizeBufB() {
|
|||
}
|
||||
}
|
||||
|
||||
int MatmulFp32BaseCPUKernel::FloatRun(int task_id) const {
|
||||
int current_start_oc = task_id * thread_stride_ * col_tile_;
|
||||
int current_rest_oc = 0;
|
||||
#if defined(ENABLE_AVX)
|
||||
current_rest_oc = params_->col_align_ - current_start_oc;
|
||||
int MatmulFp32BaseCPUKernel::ParallelRunByBatch(int task_id) const {
|
||||
int start_batch = task_id * batch_stride_;
|
||||
int end_batch = MSMIN(params_->batch, start_batch + batch_stride_);
|
||||
#ifdef ENABLE_AVX
|
||||
int col_step = params_->col_align_;
|
||||
#else
|
||||
current_rest_oc = params_->col_ - current_start_oc;
|
||||
// col need not aligned
|
||||
int col_step = params_->col_;
|
||||
#endif
|
||||
int cur_oc = MSMIN(thread_stride_ * col_tile_, current_rest_oc);
|
||||
|
||||
for (int index = start_batch; index < end_batch; ++index) {
|
||||
const float *a = a_pack_ptr_ + a_offset_[index] * params_->row_align_ * params_->deep_;
|
||||
const float *b = b_pack_ptr_ + b_offset_[index] * params_->deep_ * params_->col_align_;
|
||||
float *c = output_data_ + index * params_->row_ * col_step;
|
||||
|
||||
auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_;
|
||||
if (vec_matmul_) {
|
||||
#ifdef ENABLE_AVX
|
||||
MatVecMulAvxFp32(a, b, c, bias, params_->act_type_, params_->deep_, col_step, params_->col_align_);
|
||||
#elif defined(ENABLE_ARM64)
|
||||
MatVecMulFp32Neon64(a, b, c, bias, params_->act_type_, params_->deep_, col_step, params_->col_align_);
|
||||
#elif defined(ENABLE_ARM32)
|
||||
MatVecMulFp32Block4(a, b, c, bias, params_->act_type_, params_->deep_, col_step);
|
||||
#else
|
||||
MatVecMulFp32Block8(a, b, c, bias, params_->act_type_, params_->deep_, col_step);
|
||||
#endif
|
||||
} else {
|
||||
#ifdef ENABLE_AVX
|
||||
MatMulAvxFp32(a, b, c, bias, params_->act_type_, params_->deep_, col_step, params_->col_align_, params_->row_);
|
||||
#else
|
||||
MatMulOpt(a, b, c, bias, params_->act_type_, params_->deep_, params_->row_, col_step, params_->col_,
|
||||
OutType_Nhwc);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int MatmulFp32BaseCPUKernel::ParallelRunByOC(int task_id) const {
|
||||
int current_start_oc = task_id * oc_stride_ * col_tile_;
|
||||
#if defined(ENABLE_AVX)
|
||||
int current_rest_oc = params_->col_align_ - current_start_oc;
|
||||
#else
|
||||
int current_rest_oc = params_->col_ - current_start_oc;
|
||||
#endif
|
||||
int cur_oc = MSMIN(oc_stride_ * col_tile_, current_rest_oc);
|
||||
if (cur_oc <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -226,7 +263,7 @@ int MatmulFp32BaseCPUKernel::FloatRun(int task_id) const {
|
|||
#ifdef ENABLE_AVX
|
||||
MatVecMulAvxFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc, params_->col_align_);
|
||||
#elif defined(ENABLE_ARM64)
|
||||
int rest_align_col = MSMIN(params_->col_align_ - current_start_oc, thread_stride_ * col_tile_);
|
||||
int rest_align_col = MSMIN(params_->col_align_ - current_start_oc, oc_stride_ * col_tile_);
|
||||
MatVecMulFp32Neon64(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc, rest_align_col);
|
||||
#elif defined(ENABLE_ARM32)
|
||||
MatVecMulFp32Block4(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
|
||||
|
@ -323,13 +360,7 @@ int MatmulFp32BaseCPUKernel::ReSize() {
|
|||
if (op_parameter_->is_train_session_) {
|
||||
set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * static_cast<int>(sizeof(float)));
|
||||
}
|
||||
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_align_, col_tile_));
|
||||
#if defined(ENABLE_AVX) // thread tile by col_tile * C4NUM
|
||||
thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_ * C4NUM), thread_count_) * C4NUM;
|
||||
#else
|
||||
thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_);
|
||||
#endif
|
||||
|
||||
GetThreadCuttingPolicy();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -368,6 +399,24 @@ int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
|
||||
if (params_->batch >= op_parameter_->thread_num_) {
|
||||
thread_count_ = op_parameter_->thread_num_;
|
||||
batch_stride_ = UP_DIV(params_->batch, thread_count_);
|
||||
batch_split_ = true;
|
||||
parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByBatch;
|
||||
} else {
|
||||
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_align_, col_tile_));
|
||||
#if defined(ENABLE_AVX) // thread tile by col_tile * C4NUM
|
||||
oc_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_ * C4NUM), thread_count_) * C4NUM;
|
||||
#else
|
||||
oc_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_);
|
||||
#endif
|
||||
batch_split_ = false;
|
||||
parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByOC;
|
||||
}
|
||||
}
|
||||
|
||||
int MatmulFp32BaseCPUKernel::Run() {
|
||||
if (!params_->a_const_) {
|
||||
auto a_ptr = reinterpret_cast<float *>(in_tensors_[0]->data());
|
||||
|
@ -403,20 +452,28 @@ int MatmulFp32BaseCPUKernel::Run() {
|
|||
return ret;
|
||||
}
|
||||
|
||||
for (int i = 0; i < params_->batch; ++i) {
|
||||
batch_a_ptr_ = a_pack_ptr_ + a_offset_[i] * params_->row_align_ * params_->deep_;
|
||||
batch_b_ptr_ = b_pack_ptr_ + b_offset_[i] * params_->deep_ * params_->col_align_;
|
||||
if (batch_split_) {
|
||||
ret = ParallelLaunch(this->ms_context_, MatmulRun, this, thread_count_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "MatmulRun failed in split by batch";
|
||||
return ret;
|
||||
}
|
||||
} else {
|
||||
#ifdef ENABLE_AVX
|
||||
batch_c_ptr_ = output_data_ + i * params_->row_ * params_->col_align_;
|
||||
int col_step = params_->col_align_;
|
||||
#else
|
||||
// need not aligned
|
||||
batch_c_ptr_ = output_data_ + i * params_->row_ * params_->col_;
|
||||
int col_step = params_->col_;
|
||||
#endif
|
||||
|
||||
ret = ParallelLaunch(this->ms_context_, MatmulBaseFloatRun, this, thread_count_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "MatmulBaseFloatRun failed";
|
||||
return ret;
|
||||
for (int i = 0; i < params_->batch; ++i) {
|
||||
batch_a_ptr_ = a_pack_ptr_ + a_offset_[i] * params_->row_align_ * params_->deep_;
|
||||
batch_b_ptr_ = b_pack_ptr_ + b_offset_[i] * params_->deep_ * params_->col_align_;
|
||||
batch_c_ptr_ = output_data_ + i * params_->row_ * col_step;
|
||||
ret = ParallelLaunch(this->ms_context_, MatmulRun, this, thread_count_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "MatmulRun failed in split by oc";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -43,7 +43,10 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
|
|||
int Run() override;
|
||||
|
||||
public:
|
||||
int FloatRun(int task_id) const;
|
||||
int ParallelRunByOC(int task_id) const;
|
||||
int ParallelRunByBatch(int task_id) const;
|
||||
using ParallelRun = int (MatmulFp32BaseCPUKernel::*)(int task_id) const;
|
||||
ParallelRun parallel_fun_ = nullptr;
|
||||
|
||||
protected:
|
||||
int InitBufferA();
|
||||
|
@ -61,6 +64,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
|
|||
void FreeResizeBufB();
|
||||
int CalBroadCastBiasDataElements();
|
||||
int InitTmpOutBuffer();
|
||||
void GetThreadCuttingPolicy();
|
||||
|
||||
protected:
|
||||
MatMulParameter *params_ = nullptr;
|
||||
|
@ -75,7 +79,8 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
|
|||
int col_tile_ = 0;
|
||||
int row_tile_ = 0;
|
||||
int oc_res_ = 0;
|
||||
int thread_stride_ = 0;
|
||||
int batch_stride_ = 0;
|
||||
int oc_stride_ = 0;
|
||||
int thread_count_ = 0;
|
||||
bool vec_matmul_ = false;
|
||||
float *bias_ptr_ = nullptr;
|
||||
|
@ -87,6 +92,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
|
|||
int matrix_b_pack_size_ = -1;
|
||||
MatrixPackFun matrix_a_pack_fun_ = nullptr;
|
||||
MatrixPackFun matrix_b_pack_fun_ = nullptr;
|
||||
bool batch_split_ = false;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_FP32_BASE_H_
|
||||
|
|
Loading…
Reference in New Issue