forked from mindspore-Ecosystem/mindspore
[MS][LITE] fullconnection matmul A B matrix const node bug
This commit is contained in:
parent
b32c5c551e
commit
7aab3f07b4
|
@ -23,6 +23,11 @@ using mindspore::lite::RET_OK;
|
||||||
|
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
FullconnectionCPUKernel::~FullconnectionCPUKernel() {
|
// Destructor: hands all buffer cleanup to FreeBuf().
FullconnectionCPUKernel::~FullconnectionCPUKernel() { FreeBuf(); }
|
||||||
|
|
||||||
|
void FullconnectionCPUKernel::FreeBuf() {
|
||||||
if (a_c8_ptr_ != nullptr) {
|
if (a_c8_ptr_ != nullptr) {
|
||||||
free(a_c8_ptr_);
|
free(a_c8_ptr_);
|
||||||
a_c8_ptr_ = nullptr;
|
a_c8_ptr_ = nullptr;
|
||||||
|
@ -41,7 +46,11 @@ FullconnectionCPUKernel::~FullconnectionCPUKernel() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int FullconnectionCPUKernel::ReSize() { return RET_OK; }
|
// Rebuilds the kernel's working buffers.
//
// Releases the buffers handled by FreeBuf() and then re-runs Init().
//
// Returns the status produced by Init(), so a failure during
// re-initialisation (e.g. RET_MEMORY_FAILED from a failed malloc) reaches the
// caller instead of being reported as RET_OK.
int FullconnectionCPUKernel::ReSize() {
  FreeBuf();
  return Init();
}
|
||||||
|
|
||||||
int FullconnectionCPUKernel::Init() {
|
int FullconnectionCPUKernel::Init() {
|
||||||
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
||||||
|
@ -75,16 +84,44 @@ int FullconnectionCPUKernel::Init() {
|
||||||
return RET_MEMORY_FAILED;
|
return RET_MEMORY_FAILED;
|
||||||
}
|
}
|
||||||
memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(float));
|
memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(float));
|
||||||
RowMajor2Col8Major(reinterpret_cast<float *>(in_tensors_[1]->Data()), b_r8_ptr_, fc_param_->col_, fc_param_->deep_);
|
|
||||||
|
|
||||||
c_r8x8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float)));
|
c_r8x8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float)));
|
||||||
if (c_r8x8_ptr_ == nullptr) {
|
if (c_r8x8_ptr_ == nullptr) {
|
||||||
return RET_MEMORY_FAILED;
|
return RET_MEMORY_FAILED;
|
||||||
}
|
}
|
||||||
memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float));
|
memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float));
|
||||||
|
|
||||||
|
fc_param_->a_const_ = false;
|
||||||
|
fc_param_->b_const_ = false;
|
||||||
|
InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c8_ptr_);
|
||||||
|
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->Data()), b_r8_ptr_);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Packs matrix A from `src_ptr` into `dst_ptr` with RowMajor2Col8Major, at
// most once between resets of fc_param_->a_const_.
//
// src_ptr: source matrix A. When nullptr the call does nothing and leaves the
//          flag untouched, so a later call with real data can still pack.
// dst_ptr: destination buffer that receives the packed matrix.
//
// The first successful packing sets fc_param_->a_const_; every later call
// returns immediately while that flag stays true.
// NOTE(review): Run() also goes through this function, so after the first
// packing a call carrying new input data is skipped too — confirm that A is
// really constant whenever the flag is set.
void FullconnectionCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
  if (fc_param_->a_const_ || src_ptr == nullptr) {
    return;
  }
  fc_param_->a_const_ = true;
  // Write through the dst_ptr argument (it used to be ignored in favour of the
  // a_c8_ptr_ member), which keeps this function consistent with InitMatrixB.
  RowMajor2Col8Major(src_ptr, dst_ptr, fc_param_->row_, fc_param_->deep_);
}
|
||||||
|
|
||||||
|
// Packs matrix B from `src_ptr` into `dst_ptr` with RowMajor2Col8Major.
//
// Does nothing when fc_param_->b_const_ is already set or when `src_ptr` is
// nullptr. Otherwise it sets fc_param_->b_const_ and performs the packing, so
// later calls are skipped while the flag stays true.
void FullconnectionCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) {
  const bool already_packed = fc_param_->b_const_;
  if (already_packed || src_ptr == nullptr) {
    return;
  }
  fc_param_->b_const_ = true;
  RowMajor2Col8Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_);
}
|
||||||
|
|
||||||
int FcFp32MatmulRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
|
int FcFp32MatmulRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
|
||||||
auto fc = reinterpret_cast<FullconnectionCPUKernel *>(cdata);
|
auto fc = reinterpret_cast<FullconnectionCPUKernel *>(cdata);
|
||||||
auto error_code = fc->DoMatmul(task_id);
|
auto error_code = fc->DoMatmul(task_id);
|
||||||
|
@ -115,9 +152,11 @@ int FullconnectionCPUKernel::Run() {
|
||||||
return prepare_ret;
|
return prepare_ret;
|
||||||
}
|
}
|
||||||
auto a_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
|
auto a_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
|
||||||
|
auto b_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->Data());
|
||||||
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
|
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
|
||||||
|
|
||||||
RowMajor2Col8Major(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
|
InitMatrixA(a_ptr, a_c8_ptr_);
|
||||||
|
InitMatrixB(b_ptr, b_r8_ptr_);
|
||||||
|
|
||||||
LiteBackendParallelLaunch(FcFp32MatmulRun, this, thread_count_);
|
LiteBackendParallelLaunch(FcFp32MatmulRun, this, thread_count_);
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,11 @@ class FullconnectionCPUKernel : public FullconnectionBaseCPUKernel {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
int DoMatmul(int task_id);
|
int DoMatmul(int task_id);
|
||||||
|
void FreeBuf();
|
||||||
|
|
||||||
|
private:
|
||||||
|
void InitMatrixA(float *src_ptr, float *dst_ptr);
|
||||||
|
void InitMatrixB(float *src_ptr, float *dst_ptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
float *a_c8_ptr_;
|
float *a_c8_ptr_;
|
||||||
|
|
|
@ -78,6 +78,11 @@ int MatmulCPUKernel::Init() {
|
||||||
}
|
}
|
||||||
memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(float));
|
memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(float));
|
||||||
|
|
||||||
|
params_->a_const_ = false;
|
||||||
|
params_->b_const_ = false;
|
||||||
|
InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c8_ptr_);
|
||||||
|
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->Data()), b_r8_ptr_);
|
||||||
|
|
||||||
if (in_tensors_.size() == 3) {
|
if (in_tensors_.size() == 3) {
|
||||||
bias_ptr_ = reinterpret_cast<float *>(malloc(params_->col_8_ * sizeof(float)));
|
bias_ptr_ = reinterpret_cast<float *>(malloc(params_->col_8_ * sizeof(float)));
|
||||||
memset(bias_ptr_, 0, params_->col_8_ * sizeof(float));
|
memset(bias_ptr_, 0, params_->col_8_ * sizeof(float));
|
||||||
|
@ -89,6 +94,40 @@ int MatmulCPUKernel::Init() {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void MatmulCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
|
||||||
|
if (params_->a_const_ == true) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (src_ptr == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
params_->a_const_ = true;
|
||||||
|
|
||||||
|
if (params_->a_transpose_) {
|
||||||
|
RowMajor2Row8Major(src_ptr, dst_ptr, params_->deep_, params_->row_);
|
||||||
|
} else {
|
||||||
|
RowMajor2Col8Major(src_ptr, a_c8_ptr_, params_->row_, params_->deep_);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void MatmulCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) {
|
||||||
|
if (params_->b_const_ == true) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (src_ptr == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
params_->b_const_ = true;
|
||||||
|
|
||||||
|
if (params_->b_transpose_) {
|
||||||
|
RowMajor2Col8Major(src_ptr, dst_ptr, params_->col_, params_->deep_);
|
||||||
|
} else {
|
||||||
|
RowMajor2Row8Major(src_ptr, dst_ptr, params_->deep_, params_->col_);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
int MatmulCPUKernel::RunImpl(int task_id) {
|
int MatmulCPUKernel::RunImpl(int task_id) {
|
||||||
int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_);
|
int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_);
|
||||||
if (cur_oc <= 0) {
|
if (cur_oc <= 0) {
|
||||||
|
@ -131,16 +170,10 @@ int MatmulCPUKernel::Run() {
|
||||||
auto cur_a_ptr = a_ptr + i * a_stride;
|
auto cur_a_ptr = a_ptr + i * a_stride;
|
||||||
auto cur_b_ptr = b_ptr + i * b_stride;
|
auto cur_b_ptr = b_ptr + i * b_stride;
|
||||||
auto cur_c_ptr = c_ptr + i * c_stride;
|
auto cur_c_ptr = c_ptr + i * c_stride;
|
||||||
if (params_->a_transpose_) {
|
|
||||||
RowMajor2Row8Major(cur_a_ptr, a_c8_ptr_, params_->deep_, params_->row_);
|
InitMatrixA(cur_a_ptr, a_c8_ptr_);
|
||||||
} else {
|
InitMatrixB(cur_b_ptr, b_r8_ptr_);
|
||||||
RowMajor2Col8Major(cur_a_ptr, a_c8_ptr_, params_->row_, params_->deep_);
|
|
||||||
}
|
|
||||||
if (params_->b_transpose_) {
|
|
||||||
RowMajor2Col8Major(cur_b_ptr, b_r8_ptr_, params_->col_, params_->deep_);
|
|
||||||
} else {
|
|
||||||
RowMajor2Row8Major(cur_b_ptr, b_r8_ptr_, params_->deep_, params_->col_);
|
|
||||||
}
|
|
||||||
LiteBackendParallelLaunch(MatmulFloatRun, this, thread_count_);
|
LiteBackendParallelLaunch(MatmulFloatRun, this, thread_count_);
|
||||||
Row8x8Major2RowMajor(c_r8x8_ptr_, cur_c_ptr, params_->row_, params_->col_, params_->col_);
|
Row8x8Major2RowMajor(c_r8x8_ptr_, cur_c_ptr, params_->row_, params_->col_, params_->col_);
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,6 +35,10 @@ class MatmulCPUKernel : public MatmulBaseCPUKernel {
|
||||||
int Run() override;
|
int Run() override;
|
||||||
int RunImpl(int task_id);
|
int RunImpl(int task_id);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void InitMatrixA(float *src_ptr, float *dst_ptr);
|
||||||
|
void InitMatrixB(float *src_ptr, float *dst_ptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
float *a_c8_ptr_;
|
float *a_c8_ptr_;
|
||||||
float *b_r8_ptr_;
|
float *b_r8_ptr_;
|
||||||
|
|
|
@ -33,6 +33,8 @@ typedef struct MatMulParameter {
|
||||||
int batch;
|
int batch;
|
||||||
bool a_transpose_; /* false : row-major */
|
bool a_transpose_; /* false : row-major */
|
||||||
bool b_transpose_; /* true : col-major */
|
bool b_transpose_; /* true : col-major */
|
||||||
|
bool a_const_;
|
||||||
|
bool b_const_;
|
||||||
ActType act_type_;
|
ActType act_type_;
|
||||||
} MatMulParameter;
|
} MatMulParameter;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue