From 7aab3f07b40dc57c07121a1eb4188cd8db02e436 Mon Sep 17 00:00:00 2001 From: ling Date: Mon, 17 Aug 2020 18:55:16 +0800 Subject: [PATCH] [MS][LITE] fullconnection matmul A B matrix const node bug --- .../runtime/kernel/arm/fp32/fullconnection.cc | 45 ++++++++++++++-- .../runtime/kernel/arm/fp32/fullconnection.h | 5 ++ .../src/runtime/kernel/arm/fp32/matmul.cc | 53 +++++++++++++++---- .../lite/src/runtime/kernel/arm/fp32/matmul.h | 4 ++ .../kernel/arm/nnacl/matmul_parameter.h | 2 + 5 files changed, 96 insertions(+), 13 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc index 096127e184d..99ff4a5b0b0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc @@ -23,6 +23,11 @@ using mindspore::lite::RET_OK; namespace mindspore::kernel { FullconnectionCPUKernel::~FullconnectionCPUKernel() { + FreeBuf(); + return; +} + +void FullconnectionCPUKernel::FreeBuf() { if (a_c8_ptr_ != nullptr) { free(a_c8_ptr_); a_c8_ptr_ = nullptr; @@ -41,7 +46,11 @@ FullconnectionCPUKernel::~FullconnectionCPUKernel() { } } -int FullconnectionCPUKernel::ReSize() { return RET_OK; } +int FullconnectionCPUKernel::ReSize() { + FreeBuf(); + Init(); + return RET_OK; +} int FullconnectionCPUKernel::Init() { if (context_->infer_shape_interrupt_ && !context_->running_) { @@ -75,16 +84,44 @@ int FullconnectionCPUKernel::Init() { return RET_MEMORY_FAILED; } memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(float)); - RowMajor2Col8Major(reinterpret_cast(in_tensors_[1]->Data()), b_r8_ptr_, fc_param_->col_, fc_param_->deep_); c_r8x8_ptr_ = reinterpret_cast(malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float))); if (c_r8x8_ptr_ == nullptr) { return RET_MEMORY_FAILED; } memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float)); + + fc_param_->a_const_ = false; + fc_param_->b_const_ = false; + InitMatrixA(reinterpret_cast(in_tensors_[0]->Data()), a_c8_ptr_); + InitMatrixB(reinterpret_cast(in_tensors_[1]->Data()), b_r8_ptr_); return RET_OK; } +void FullconnectionCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) { + if (fc_param_->a_const_ == true) { + return; + } + if (src_ptr == nullptr) { + return; + } + fc_param_->a_const_ = true; + RowMajor2Col8Major(src_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_); + return; +} + +void FullconnectionCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) { + if (fc_param_->b_const_ == true) { + return; + } + if (src_ptr == nullptr) { + return; + } + fc_param_->b_const_ = true; + RowMajor2Col8Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_); + return; +} + int FcFp32MatmulRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { auto fc = reinterpret_cast(cdata); auto error_code = fc->DoMatmul(task_id); @@ -115,9 +152,11 @@ int FullconnectionCPUKernel::Run() { return prepare_ret; } auto a_ptr = reinterpret_cast(in_tensors_.at(0)->Data()); + auto b_ptr = reinterpret_cast(in_tensors_.at(1)->Data()); auto output_ptr = reinterpret_cast(out_tensors_.at(0)->Data()); - RowMajor2Col8Major(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_); + InitMatrixA(a_ptr, a_c8_ptr_); + InitMatrixB(b_ptr, b_r8_ptr_); LiteBackendParallelLaunch(FcFp32MatmulRun, this, thread_count_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h index b8b0d5defe1..9f66b037575 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h @@ -40,6 +40,11 @@ class FullconnectionCPUKernel : public FullconnectionBaseCPUKernel { public: int DoMatmul(int task_id); + void FreeBuf(); + + private: + void InitMatrixA(float *src_ptr, float *dst_ptr); + void InitMatrixB(float *src_ptr, float *dst_ptr); private: float *a_c8_ptr_; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc index 2e35323ba6d..5aa685be403 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc @@ -78,6 +78,11 @@ int MatmulCPUKernel::Init() { } memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(float)); + params_->a_const_ = false; + params_->b_const_ = false; + InitMatrixA(reinterpret_cast(in_tensors_[0]->Data()), a_c8_ptr_); + InitMatrixB(reinterpret_cast(in_tensors_[1]->Data()), b_r8_ptr_); + if (in_tensors_.size() == 3) { bias_ptr_ = reinterpret_cast(malloc(params_->col_8_ * sizeof(float))); memset(bias_ptr_, 0, params_->col_8_ * sizeof(float)); @@ -89,6 +94,40 @@ int MatmulCPUKernel::Init() { return RET_OK; } +void MatmulCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) { + if (params_->a_const_ == true) { + return; + } + if (src_ptr == nullptr) { + return; + } + params_->a_const_ = true; + + if (params_->a_transpose_) { + RowMajor2Row8Major(src_ptr, dst_ptr, params_->deep_, params_->row_); + } else { + RowMajor2Col8Major(src_ptr, a_c8_ptr_, params_->row_, params_->deep_); + } + return; +} + +void MatmulCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) { + if (params_->b_const_ == true) { + return; + } + if (src_ptr == nullptr) { + return; + } + params_->b_const_ = true; + + if (params_->b_transpose_) { + RowMajor2Col8Major(src_ptr, dst_ptr, params_->col_, params_->deep_); + } else { + RowMajor2Row8Major(src_ptr, dst_ptr, params_->deep_, params_->col_); + } + return; +} + int MatmulCPUKernel::RunImpl(int task_id) { int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_); if (cur_oc <= 0) { @@ -131,16 +170,10 @@ int MatmulCPUKernel::Run() { auto cur_a_ptr = a_ptr + i * a_stride; auto cur_b_ptr = b_ptr + i * b_stride; auto cur_c_ptr = c_ptr + i * c_stride; - if (params_->a_transpose_) { - RowMajor2Row8Major(cur_a_ptr, a_c8_ptr_, params_->deep_, params_->row_); - } else { - RowMajor2Col8Major(cur_a_ptr, a_c8_ptr_, params_->row_, params_->deep_); - } - if (params_->b_transpose_) { - RowMajor2Col8Major(cur_b_ptr, b_r8_ptr_, params_->col_, params_->deep_); - } else { - RowMajor2Row8Major(cur_b_ptr, b_r8_ptr_, params_->deep_, params_->col_); - } + + InitMatrixA(cur_a_ptr, a_c8_ptr_); + InitMatrixB(cur_b_ptr, b_r8_ptr_); + LiteBackendParallelLaunch(MatmulFloatRun, this, thread_count_); Row8x8Major2RowMajor(c_r8x8_ptr_, cur_c_ptr, params_->row_, params_->col_, params_->col_); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h index 4642c8ad6c0..38c0e445ac2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h @@ -35,6 +35,10 @@ class MatmulCPUKernel : public MatmulBaseCPUKernel { int Run() override; int RunImpl(int task_id); + private: + void InitMatrixA(float *src_ptr, float *dst_ptr); + void InitMatrixB(float *src_ptr, float *dst_ptr); + private: float *a_c8_ptr_; float *b_r8_ptr_; diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h index be01e4beb29..b2b24064d35 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h @@ -33,6 +33,8 @@ typedef struct MatMulParameter { int batch; bool a_transpose_; /* false : row-major */ bool b_transpose_; /* true : col-major */ + bool a_const_; + bool b_const_; ActType act_type_; } MatMulParameter;