forked from mindspore-Ecosystem/mindspore
!27907 [MS][LITE][CPU] matmul nonconstant weight optimize
Merge pull request !27907 from liuzhongkai/matmul512_new
This commit is contained in:
commit
fde62142a0
|
@ -169,3 +169,4 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_f
|
|||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_fp32.c:nnacl_gemm_avx512_4x64_kernel_nhwc_fp32
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_fp32.c:nnacl_gemm_avx512_5x64_kernel_nhwc_fp32
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_fp32.c:nnacl_gemm_avx512_6x64_kernel_nhwc_fp32
|
||||
mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::Run
|
||||
|
|
|
@ -2172,3 +2172,44 @@ void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, c
|
|||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Degenerate GEMM used when deep == 1 and col == 1: B and bias collapse to
// scalars, so the whole batch reduces to c[i] = a[i] * b[0] + bias[0] for
// i in [0, row). Vectorized per ISA, with a scalar tail for the remainder.
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row) {
  int index = 0;
#ifdef ENABLE_AVX512
  __m512 b_data16 = _mm512_set1_ps(b[0]);
  __m512 bias_data16 = _mm512_set1_ps(bias[0]);
#endif
#ifdef ENABLE_AVX
  __m256 b_data8 = _mm256_set1_ps(b[0]);
  __m256 bias_data8 = _mm256_set1_ps(bias[0]);
#endif
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
  MS_FLOAT32X4 b_data4 = MS_MOVQ_F32(b[0]);
  MS_FLOAT32X4 bias_data4 = MS_MOVQ_F32(bias[0]);
#endif

#ifdef ENABLE_AVX512
  // NOTE: <= (not <) so that a row that is an exact multiple of C16NUM is
  // fully handled by the vector loop instead of leaking into the scalar tail.
  for (; index <= row - C16NUM; index += C16NUM) {
    __m512 a_data = _mm512_loadu_ps(a + index);
    _mm512_storeu_ps(c + index, b_data16 * a_data + bias_data16);
  }
#endif

#ifdef ENABLE_AVX
  for (; index <= row - C8NUM; index += C8NUM) {
    __m256 a_data = _mm256_loadu_ps(a + index);
    _mm256_storeu_ps(c + index, b_data8 * a_data + bias_data8);
  }
#endif

#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
  for (; index <= row - C4NUM; index += C4NUM) {
    MS_FLOAT32X4 a_data = MS_LDQ_F32(a + index);
    MS_STQ_F32(c + index, b_data4 * a_data + bias_data4);
  }
#endif

  // Scalar tail: remaining row % vector-width elements.
  for (; index < row; ++index) {
    c[index] = a[index] * b[0] + bias[0];
  }
}
|
||||
|
|
|
@ -125,6 +125,8 @@ void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *
|
|||
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
|
||||
int col, int stride, int out_type);
|
||||
|
||||
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -28,17 +28,28 @@ namespace mindspore::kernel {
|
|||
int MatmulRun(const void *cdata, int task_id, float, float) {
|
||||
CHECK_NULL_RETURN(cdata);
|
||||
auto op = reinterpret_cast<const MatmulFp32BaseCPUKernel *>(cdata);
|
||||
auto error_code = (op->*(op->parallel_fun_))(task_id);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "MatmulRun error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||
return RET_ERROR;
|
||||
if (op->is_pack_) {
|
||||
auto error_code = (op->*(op->parallel_fun_))(task_id);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "MatmulRun error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
auto error_code = op->ParallelRunIsNotPackByBatch(task_id);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "MatmulRun error task_id[" << task_id << "] error_code[" << error_code
|
||||
<< "] in ParallelRunIsNotPackByBatch";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
MatmulFp32BaseCPUKernel::~MatmulFp32BaseCPUKernel() {
  // Reconstructed from the fused diff: the packed A/B buffers are only owned
  // by this kernel when is_pack_ is true. In the no-pack path
  // a_pack_ptr_/b_pack_ptr_ alias the input tensors' data (see Run), so they
  // must not be freed here.
  if (is_pack_) {
    FreeResizeBufA();
    FreeResizeBufB();
  }
  FreeBiasBuf();
}
|
||||
|
||||
|
@ -188,14 +199,14 @@ void MatmulFp32BaseCPUKernel::FreeBiasBuf() {
|
|||
}
|
||||
|
||||
void MatmulFp32BaseCPUKernel::FreeResizeBufA() {
|
||||
if (!vec_matmul_ && !op_parameter_->is_train_session_ && a_pack_ptr_ != nullptr) {
|
||||
if (!vec_matmul_ && !op_parameter_->is_train_session_ && a_pack_ptr_ != nullptr && is_pack_) {
|
||||
ms_context_->allocator->Free(a_pack_ptr_);
|
||||
}
|
||||
a_pack_ptr_ = nullptr;
|
||||
}
|
||||
|
||||
void MatmulFp32BaseCPUKernel::FreeResizeBufB() {
|
||||
if (!op_parameter_->is_train_session_ && b_pack_ptr_ != nullptr) {
|
||||
if (!op_parameter_->is_train_session_ && b_pack_ptr_ != nullptr && is_pack_) {
|
||||
ms_context_->allocator->Free(b_pack_ptr_);
|
||||
}
|
||||
b_pack_ptr_ = nullptr;
|
||||
|
@ -239,6 +250,22 @@ int MatmulFp32BaseCPUKernel::ParallelRunByBatch(int task_id) const {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const {
|
||||
int start_batch = task_id * batch_stride_;
|
||||
int end_batch = MSMIN(params_->batch, start_batch + batch_stride_);
|
||||
float bias = 0;
|
||||
if (bias_ptr_ != nullptr) {
|
||||
bias = bias_ptr_[0];
|
||||
}
|
||||
for (int index = start_batch; index < end_batch; ++index) {
|
||||
const float *a = a_pack_ptr_ + index * params_->row_ * params_->deep_;
|
||||
const float *b = b_pack_ptr_ + index * params_->deep_ * params_->col_;
|
||||
float *c = output_data_ + index * params_->row_ * params_->col_;
|
||||
GemmIsNotPack(a, b, c, &bias, params_->row_);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int MatmulFp32BaseCPUKernel::ParallelRunByOC(int task_id) const {
|
||||
int current_start_oc = task_id * oc_stride_ * col_tile_;
|
||||
#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
|
||||
|
@ -392,23 +419,27 @@ void MatmulFp32BaseCPUKernel::ResizeParameter() {
|
|||
int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
|
||||
auto out_data = reinterpret_cast<float *>(out_tensors_.front()->data());
|
||||
MS_ASSERT(out_data != nullptr);
|
||||
#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
|
||||
if (oc_res_ != 0) { // avx matmul need to malloc dst aligned to C8NUM and avx512 need to aligned to C16NUM
|
||||
int out_channel = params_->col_;
|
||||
int oc_block_num = UP_DIV(out_channel, col_tile_);
|
||||
MS_ASSERT(ms_context_->allocator != nullptr);
|
||||
output_data_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(
|
||||
params_->batch * params_->row_ * oc_block_num * col_tile_ * static_cast<int>(sizeof(float))));
|
||||
if (output_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tmp output data failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
} else { // need to malloc dst to algin block
|
||||
if (!is_pack_) {
|
||||
output_data_ = out_data;
|
||||
}
|
||||
} else {
|
||||
#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
|
||||
if (oc_res_ != 0) { // avx matmul need to malloc dst aligned to C8NUM
|
||||
int out_channel = params_->col_;
|
||||
int oc_block_num = UP_DIV(out_channel, col_tile_);
|
||||
MS_ASSERT(ms_context_->allocator != nullptr);
|
||||
output_data_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(
|
||||
params_->batch * params_->row_ * oc_block_num * col_tile_ * static_cast<int>(sizeof(float))));
|
||||
if (output_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tmp output data failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
} else { // need to malloc dst to algin block
|
||||
output_data_ = out_data;
|
||||
}
|
||||
#else
|
||||
output_data_ = out_data;
|
||||
output_data_ = out_data;
|
||||
#endif
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -431,41 +462,46 @@ void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
|
|||
}
|
||||
|
||||
int MatmulFp32BaseCPUKernel::Run() {
|
||||
if (!params_->a_const_) {
|
||||
auto a_ptr = reinterpret_cast<float *>(in_tensors_[0]->data());
|
||||
CHECK_NULL_RETURN(a_ptr);
|
||||
if (RET_OK != InitBufferA()) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = InitMatrixA(a_ptr);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InitMatrixA failed!";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
if (!params_->b_const_) {
|
||||
auto b_ptr = reinterpret_cast<float *>(in_tensors_[1]->data());
|
||||
CHECK_NULL_RETURN(b_ptr);
|
||||
if (RET_OK != InitBufferB()) {
|
||||
FreeResizeBufA();
|
||||
return RET_ERROR;
|
||||
if (params_->col_ == 1 && params_->deep_ == 1) {
|
||||
b_pack_ptr_ = b_ptr;
|
||||
is_pack_ = false;
|
||||
} else {
|
||||
if (RET_OK != InitBufferB()) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = InitMatrixB(b_ptr);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InitMatrixB failed!";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
auto ret = InitMatrixB(b_ptr);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InitMatrixB failed!";
|
||||
return ret;
|
||||
}
|
||||
if (!params_->a_const_) {
|
||||
auto a_ptr = reinterpret_cast<float *>(in_tensors_[0]->data());
|
||||
CHECK_NULL_RETURN(a_ptr);
|
||||
if (!is_pack_) {
|
||||
a_pack_ptr_ = a_ptr;
|
||||
} else {
|
||||
if (RET_OK != InitBufferA()) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = InitMatrixA(a_ptr);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InitMatrixA failed!";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto ret = InitTmpOutBuffer();
|
||||
if (ret != RET_OK) {
|
||||
FreeResizeBufA();
|
||||
FreeResizeBufB();
|
||||
MS_LOG(ERROR) << "InitTmpOutBuffer error!";
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (batch_split_) {
|
||||
if (batch_split_ || !is_pack_) {
|
||||
ret = ParallelLaunch(this->ms_context_, MatmulRun, this, thread_count_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "MatmulRun failed in split by batch";
|
||||
|
@ -491,7 +527,7 @@ int MatmulFp32BaseCPUKernel::Run() {
|
|||
}
|
||||
|
||||
#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
|
||||
if (oc_res_ != 0) {
|
||||
if (oc_res_ != 0 && is_pack_) {
|
||||
auto out_data = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
|
||||
PackNHWCXToNHWCFp32(output_data_, out_data, params_->batch, params_->row_, params_->col_, col_tile_);
|
||||
ms_context_->allocator->Free(output_data_);
|
||||
|
|
|
@ -49,8 +49,10 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
|
|||
public:
|
||||
int ParallelRunByOC(int task_id) const;
|
||||
int ParallelRunByBatch(int task_id) const;
|
||||
int ParallelRunIsNotPackByBatch(int task_id) const;
|
||||
using ParallelRun = int (MatmulFp32BaseCPUKernel::*)(int task_id) const;
|
||||
ParallelRun parallel_fun_ = nullptr;
|
||||
bool is_pack_ = true;
|
||||
|
||||
protected:
|
||||
int InitBufferA();
|
||||
|
|
Loading…
Reference in New Issue