!7820 [MSLITE] Fix the bug of MatVecMul fp32

Merge pull request !7820 from zhanyuan/dev
This commit is contained in:
mindspore-ci-bot 2020-10-27 16:22:42 +08:00 committed by Gitee
commit a9eedbc0ea
2 changed files with 10 additions and 10 deletions

View File

@ -49,8 +49,8 @@ typedef struct MatMulParameter {
bool b_transpose_; /* true : col-major */
bool a_const_;
bool b_const_;
bool a_has_shape_;
bool b_has_shape_;
bool a_init_shape_;
bool b_init_shape_;
ActType act_type_;
} MatMulParameter;

View File

@ -57,7 +57,7 @@ int MatmulCPUKernel::MallocMatrixABuffer() {
params_->batch = batch;
params_->row_ = params_->a_transpose_ ? a_shape[a_shape.size() - 1] : a_shape[a_shape.size() - 2];
#ifdef ENABLE_ARM64
if (params_->row_ == 1) {
if (params_->a_init_shape_ && params_->row_ == 1) {
is_vector_a_ = true;
}
#endif
@ -134,7 +134,7 @@ int MatmulCPUKernel::InitBias() {
}
int MatmulCPUKernel::ReSize() {
if (params_->a_const_ == false || params_->a_has_shape_ == false) {
if (params_->a_const_ == false || params_->a_init_shape_ == false) {
if (a_pack_ptr_ != nullptr) {
free(a_pack_ptr_);
a_pack_ptr_ = nullptr;
@ -145,7 +145,7 @@ int MatmulCPUKernel::ReSize() {
return RET_ERROR;
}
}
if (params_->b_const_ == false || params_->b_has_shape_ == false) {
if (params_->b_const_ == false || params_->b_init_shape_ == false) {
if (b_pack_ptr_ != nullptr) {
free(b_pack_ptr_);
b_pack_ptr_ = nullptr;
@ -222,16 +222,16 @@ void MatmulCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) {
}
int MatmulCPUKernel::Init() {
params_->a_has_shape_ = (in_tensors_[0]->shape().size() != 0);
params_->b_has_shape_ = (in_tensors_[1]->shape().size() != 0);
if (params_->a_has_shape_ == true) {
params_->a_init_shape_ = (in_tensors_[0]->shape().size() != 0);
params_->b_init_shape_ = (in_tensors_[1]->shape().size() != 0);
if (params_->a_init_shape_ == true) {
auto ret = MallocMatrixABuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Matmul fp32 malloc matrix a buffer failed";
return RET_ERROR;
}
}
if (params_->b_has_shape_ == true) {
if (params_->b_init_shape_ == true) {
auto ret = MallocMatrixBBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Matmul fp32 malloc matrix b buffer failed";
@ -300,7 +300,7 @@ int MatmulCPUKernel::Run() {
}
}
if (params_->b_const_ == false || is_train()) {
if (is_vector_a_) {
if (is_vector_a_ && params_->b_transpose_) {
b_ptr_ = b_src;
} else {
InitMatrixB(b_src, b_pack_ptr_);