!27907 [MS][LITE][CPU] matmul nonconstant weight optimize

Merge pull request !27907 from liuzhongkai/matmul512_new
i-robot 2021-12-22 09:56:27 +00:00, committed by Gitee
commit fde62142a0
5 changed files with 128 additions and 46 deletions
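
Summary: this change adds a fast path for MatMul kernels whose right-hand matrix B is non-constant and degenerates to a single scalar (params_->col_ == 1 and params_->deep_ == 1). In that case Run() sets is_pack_ = false, skips packing entirely (a_pack_ptr_ and b_pack_ptr_ point at the raw tensor data), writes results straight into the output tensor, and splits the work across threads by batch via the new ParallelRunIsNotPackByBatch, which calls the new nnacl kernel GemmIsNotPack to compute c[i] = a[i] * b[0] + bias[0].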


@@ -169,3 +169,4 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_fp32.c:nnacl_gemm_avx512_4x64_kernel_nhwc_fp32
 mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_fp32.c:nnacl_gemm_avx512_4x64_kernel_nhwc_fp32
 mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_fp32.c:nnacl_gemm_avx512_5x64_kernel_nhwc_fp32
 mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_fp32.c:nnacl_gemm_avx512_6x64_kernel_nhwc_fp32
+mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::Run

View File

@@ -2172,3 +2172,44 @@ void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, c
 }
 #endif
 #endif
+
+void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row) {
+  int index = 0;
+#ifdef ENABLE_AVX512
+  __m512 b_data16 = _mm512_set1_ps(b[0]);
+  __m512 bias_data16 = _mm512_set1_ps(bias[0]);
+#endif
+#ifdef ENABLE_AVX
+  __m256 b_data8 = _mm256_set1_ps(b[0]);
+  __m256 bias_data8 = _mm256_set1_ps(bias[0]);
+#endif
+#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
+  MS_FLOAT32X4 b_data4 = MS_MOVQ_F32(b[0]);
+  MS_FLOAT32X4 bias_data4 = MS_MOVQ_F32(bias[0]);
+#endif
+
+#ifdef ENABLE_AVX512
+  for (; index < row - C16NUM; index += C16NUM) {
+    __m512 a_data = _mm512_loadu_ps(a + index);
+    _mm512_storeu_ps(c + index, b_data16 * a_data + bias_data16);
+  }
+#endif
+
+#ifdef ENABLE_AVX
+  for (; index < row - C8NUM; index += C8NUM) {
+    __m256 a_data = _mm256_loadu_ps(a + index);
+    _mm256_storeu_ps(c + index, b_data8 * a_data + bias_data8);
+  }
+#endif
+
+#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
+  for (; index < row - C4NUM; index += C4NUM) {
+    MS_FLOAT32X4 a_data = MS_LDQ_F32(a + index);
+    MS_STQ_F32(c + index, b_data4 * a_data + bias_data4);
+  }
+#endif
+
+  for (; index < row; ++index) {
+    c[index] = a[index] * b[0] + bias[0];
+  }
+}
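
Every SIMD branch above computes the same result as the scalar tail; for reference, the whole kernel is equivalent to this scalar loop (a minimal sketch for illustration — GemmIsNotPackRef is our name, not part of the patch):

/* Scalar reference for GemmIsNotPack: with deep == 1 and col == 1 the matmul
 * collapses to broadcasting the single weight b[0] and the single bias[0]
 * across the row elements of a. */
void GemmIsNotPackRef(const float *a, const float *b, float *c, const float *bias, int row) {
  for (int i = 0; i < row; ++i) {
    c[i] = a[i] * b[0] + bias[0];
  }
}

Two implementation notes: the vector loops use the strict bound index < row - CxNUM rather than index <= row - CxNUM, so a row count that is an exact multiple of a vector width falls through to a narrower loop or the scalar tail — still correct, just fractionally less vectorized. And infix expressions such as b_data16 * a_data + bias_data16 on __m512/__m256 values compile via the GCC/Clang vector-extension operators rather than portable _mm*_mul/_mm*_fmadd intrinsics.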


@@ -125,6 +125,8 @@ void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *
 void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
                 int col, int stride, int out_type);
+void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row);
+
 #ifdef __cplusplus
 }
 #endif


@@ -28,17 +28,28 @@ namespace mindspore::kernel {
 int MatmulRun(const void *cdata, int task_id, float, float) {
   CHECK_NULL_RETURN(cdata);
   auto op = reinterpret_cast<const MatmulFp32BaseCPUKernel *>(cdata);
-  auto error_code = (op->*(op->parallel_fun_))(task_id);
-  if (error_code != RET_OK) {
-    MS_LOG(ERROR) << "MatmulRun error task_id[" << task_id << "] error_code[" << error_code << "]";
-    return RET_ERROR;
+  if (op->is_pack_) {
+    auto error_code = (op->*(op->parallel_fun_))(task_id);
+    if (error_code != RET_OK) {
+      MS_LOG(ERROR) << "MatmulRun error task_id[" << task_id << "] error_code[" << error_code << "]";
+      return RET_ERROR;
+    }
+  } else {
+    auto error_code = op->ParallelRunIsNotPackByBatch(task_id);
+    if (error_code != RET_OK) {
+      MS_LOG(ERROR) << "MatmulRun error task_id[" << task_id << "] error_code[" << error_code
+                    << "] in ParallelRunIsNotPackByBatch";
+      return RET_ERROR;
+    }
   }
   return RET_OK;
 }
 
 MatmulFp32BaseCPUKernel::~MatmulFp32BaseCPUKernel() {
-  FreeResizeBufA();
-  FreeResizeBufB();
+  if (is_pack_) {
+    FreeResizeBufA();
+    FreeResizeBufB();
+  }
   FreeBiasBuf();
 }
@@ -188,14 +199,14 @@ void MatmulFp32BaseCPUKernel::FreeBiasBuf() {
 }
 
 void MatmulFp32BaseCPUKernel::FreeResizeBufA() {
-  if (!vec_matmul_ && !op_parameter_->is_train_session_ && a_pack_ptr_ != nullptr) {
+  if (!vec_matmul_ && !op_parameter_->is_train_session_ && a_pack_ptr_ != nullptr && is_pack_) {
     ms_context_->allocator->Free(a_pack_ptr_);
   }
   a_pack_ptr_ = nullptr;
 }
 
 void MatmulFp32BaseCPUKernel::FreeResizeBufB() {
-  if (!op_parameter_->is_train_session_ && b_pack_ptr_ != nullptr) {
+  if (!op_parameter_->is_train_session_ && b_pack_ptr_ != nullptr && is_pack_) {
     ms_context_->allocator->Free(b_pack_ptr_);
   }
   b_pack_ptr_ = nullptr;
@@ -239,6 +250,22 @@ int MatmulFp32BaseCPUKernel::ParallelRunByBatch(int task_id) const {
   return RET_OK;
 }
 
+int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const {
+  int start_batch = task_id * batch_stride_;
+  int end_batch = MSMIN(params_->batch, start_batch + batch_stride_);
+  float bias = 0;
+  if (bias_ptr_ != nullptr) {
+    bias = bias_ptr_[0];
+  }
+  for (int index = start_batch; index < end_batch; ++index) {
+    const float *a = a_pack_ptr_ + index * params_->row_ * params_->deep_;
+    const float *b = b_pack_ptr_ + index * params_->deep_ * params_->col_;
+    float *c = output_data_ + index * params_->row_ * params_->col_;
+    GemmIsNotPack(a, b, c, &bias, params_->row_);
+  }
+  return RET_OK;
+}
+
 int MatmulFp32BaseCPUKernel::ParallelRunByOC(int task_id) const {
   int current_start_oc = task_id * oc_stride_ * col_tile_;
 #if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
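
A worked example of the batch split (assuming batch_stride_ is the per-task batch quota computed by GetThreadCuttingPolicy, e.g. UP_DIV(batch, thread_count) — the formula is outside this diff): with params_->batch == 10 and batch_stride_ == 3, task 0 covers batches 0-2, task 1 covers 3-5, task 2 covers 6-8, and task 3 covers only batch 9, because MSMIN clamps end_batch to params_->batch. On this path deep_ == 1 and col_ == 1, so per batch a and c each advance by row_ floats while b advances by a single float, and the one bias value is read once before the loop.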
@@ -392,23 +419,27 @@ void MatmulFp32BaseCPUKernel::ResizeParameter() {
 int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
   auto out_data = reinterpret_cast<float *>(out_tensors_.front()->data());
   MS_ASSERT(out_data != nullptr);
-#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
-  if (oc_res_ != 0) {  // avx matmul need to malloc dst aligned to C8NUM and avx512 need to aligned to C16NUM
-    int out_channel = params_->col_;
-    int oc_block_num = UP_DIV(out_channel, col_tile_);
-    MS_ASSERT(ms_context_->allocator != nullptr);
-    output_data_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(
-      params_->batch * params_->row_ * oc_block_num * col_tile_ * static_cast<int>(sizeof(float))));
-    if (output_data_ == nullptr) {
-      MS_LOG(ERROR) << "malloc tmp output data failed.";
-      return RET_NULL_PTR;
-    }
-  } else {  // need to malloc dst to algin block
+  if (!is_pack_) {
     output_data_ = out_data;
-  }
+  } else {
+#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
+    if (oc_res_ != 0) {  // avx matmul need to malloc dst aligned to C8NUM
+      int out_channel = params_->col_;
+      int oc_block_num = UP_DIV(out_channel, col_tile_);
+      MS_ASSERT(ms_context_->allocator != nullptr);
+      output_data_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(
+        params_->batch * params_->row_ * oc_block_num * col_tile_ * static_cast<int>(sizeof(float))));
+      if (output_data_ == nullptr) {
+        MS_LOG(ERROR) << "malloc tmp output data failed.";
+        return RET_NULL_PTR;
+      }
+    } else {  // need to malloc dst to algin block
+      output_data_ = out_data;
+    }
 #else
-  output_data_ = out_data;
+    output_data_ = out_data;
 #endif
+  }
   return RET_OK;
 }
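
On the unpacked path, output_data_ aliases the output tensor directly (just as a_pack_ptr_ and b_pack_ptr_ alias the input tensors in Run()), so this kernel allocates nothing. That is why the destructor, FreeResizeBufA/FreeResizeBufB, and the post-run repack at the end of Run() (oc_res_ != 0 && is_pack_) are all gated on is_pack_: calling allocator->Free on tensor-owned memory would corrupt it.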
@@ -431,41 +462,46 @@ void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
 }
 
 int MatmulFp32BaseCPUKernel::Run() {
-  if (!params_->a_const_) {
-    auto a_ptr = reinterpret_cast<float *>(in_tensors_[0]->data());
-    CHECK_NULL_RETURN(a_ptr);
-    if (RET_OK != InitBufferA()) {
-      return RET_ERROR;
-    }
-    auto ret = InitMatrixA(a_ptr);
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "InitMatrixA failed!";
-      return ret;
-    }
-  }
   if (!params_->b_const_) {
     auto b_ptr = reinterpret_cast<float *>(in_tensors_[1]->data());
     CHECK_NULL_RETURN(b_ptr);
-    if (RET_OK != InitBufferB()) {
-      FreeResizeBufA();
-      return RET_ERROR;
-    }
-    auto ret = InitMatrixB(b_ptr);
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "InitMatrixB failed!";
-      return ret;
-    }
+    if (params_->col_ == 1 && params_->deep_ == 1) {
+      b_pack_ptr_ = b_ptr;
+      is_pack_ = false;
+    } else {
+      if (RET_OK != InitBufferB()) {
+        return RET_ERROR;
+      }
+      auto ret = InitMatrixB(b_ptr);
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "InitMatrixB failed!";
+        return ret;
+      }
+    }
   }
+  if (!params_->a_const_) {
+    auto a_ptr = reinterpret_cast<float *>(in_tensors_[0]->data());
+    CHECK_NULL_RETURN(a_ptr);
+    if (!is_pack_) {
+      a_pack_ptr_ = a_ptr;
+    } else {
+      if (RET_OK != InitBufferA()) {
+        return RET_ERROR;
+      }
+      auto ret = InitMatrixA(a_ptr);
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "InitMatrixA failed!";
+        return ret;
+      }
+    }
+  }
   auto ret = InitTmpOutBuffer();
   if (ret != RET_OK) {
     FreeResizeBufA();
     FreeResizeBufB();
     MS_LOG(ERROR) << "InitTmpOutBuffer error!";
     return ret;
   }
-  if (batch_split_) {
+  if (batch_split_ || !is_pack_) {
     ret = ParallelLaunch(this->ms_context_, MatmulRun, this, thread_count_);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "MatmulRun failed in split by batch";
@@ -491,7 +527,7 @@ int MatmulFp32BaseCPUKernel::Run() {
   }
 
 #if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
-  if (oc_res_ != 0) {
+  if (oc_res_ != 0 && is_pack_) {
     auto out_data = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
     PackNHWCXToNHWCFp32(output_data_, out_data, params_->batch, params_->row_, params_->col_, col_tile_);
     ms_context_->allocator->Free(output_data_);


@@ -49,8 +49,10 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
  public:
   int ParallelRunByOC(int task_id) const;
   int ParallelRunByBatch(int task_id) const;
+  int ParallelRunIsNotPackByBatch(int task_id) const;
   using ParallelRun = int (MatmulFp32BaseCPUKernel::*)(int task_id) const;
   ParallelRun parallel_fun_ = nullptr;
+  bool is_pack_ = true;
 
  protected:
   int InitBufferA();
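
Two details worth noting: is_pack_ defaults to true, so every existing packed path behaves exactly as before unless Run() detects the scalar-B case; and ParallelRunIsNotPackByBatch is declared public, like the other ParallelRun* entry points, because the free function MatmulRun dispatches to it through the kernel pointer.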