forked from mindspore-Ecosystem/mindspore
!7820 [MSLITE] Fix the bug of MatVecMul fp32
Merge pull request !7820 from zhanyuan/dev
This commit is contained in:
commit
a9eedbc0ea
|
@ -49,8 +49,8 @@ typedef struct MatMulParameter {
|
|||
bool b_transpose_; /* true : col-major */
|
||||
bool a_const_;
|
||||
bool b_const_;
|
||||
bool a_has_shape_;
|
||||
bool b_has_shape_;
|
||||
bool a_init_shape_;
|
||||
bool b_init_shape_;
|
||||
ActType act_type_;
|
||||
} MatMulParameter;
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ int MatmulCPUKernel::MallocMatrixABuffer() {
|
|||
params_->batch = batch;
|
||||
params_->row_ = params_->a_transpose_ ? a_shape[a_shape.size() - 1] : a_shape[a_shape.size() - 2];
|
||||
#ifdef ENABLE_ARM64
|
||||
if (params_->row_ == 1) {
|
||||
if (params_->a_init_shape_ && params_->row_ == 1) {
|
||||
is_vector_a_ = true;
|
||||
}
|
||||
#endif
|
||||
|
@ -134,7 +134,7 @@ int MatmulCPUKernel::InitBias() {
|
|||
}
|
||||
|
||||
int MatmulCPUKernel::ReSize() {
|
||||
if (params_->a_const_ == false || params_->a_has_shape_ == false) {
|
||||
if (params_->a_const_ == false || params_->a_init_shape_ == false) {
|
||||
if (a_pack_ptr_ != nullptr) {
|
||||
free(a_pack_ptr_);
|
||||
a_pack_ptr_ = nullptr;
|
||||
|
@ -145,7 +145,7 @@ int MatmulCPUKernel::ReSize() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
if (params_->b_const_ == false || params_->b_has_shape_ == false) {
|
||||
if (params_->b_const_ == false || params_->b_init_shape_ == false) {
|
||||
if (b_pack_ptr_ != nullptr) {
|
||||
free(b_pack_ptr_);
|
||||
b_pack_ptr_ = nullptr;
|
||||
|
@ -222,16 +222,16 @@ void MatmulCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) {
|
|||
}
|
||||
|
||||
int MatmulCPUKernel::Init() {
|
||||
params_->a_has_shape_ = (in_tensors_[0]->shape().size() != 0);
|
||||
params_->b_has_shape_ = (in_tensors_[1]->shape().size() != 0);
|
||||
if (params_->a_has_shape_ == true) {
|
||||
params_->a_init_shape_ = (in_tensors_[0]->shape().size() != 0);
|
||||
params_->b_init_shape_ = (in_tensors_[1]->shape().size() != 0);
|
||||
if (params_->a_init_shape_ == true) {
|
||||
auto ret = MallocMatrixABuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Matmul fp32 malloc matrix a buffer failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
if (params_->b_has_shape_ == true) {
|
||||
if (params_->b_init_shape_ == true) {
|
||||
auto ret = MallocMatrixBBuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Matmul fp32 malloc matrix b buffer failed";
|
||||
|
@ -300,7 +300,7 @@ int MatmulCPUKernel::Run() {
|
|||
}
|
||||
}
|
||||
if (params_->b_const_ == false || is_train()) {
|
||||
if (is_vector_a_) {
|
||||
if (is_vector_a_ && params_->b_transpose_) {
|
||||
b_ptr_ = b_src;
|
||||
} else {
|
||||
InitMatrixB(b_src, b_pack_ptr_);
|
||||
|
|
Loading…
Reference in New Issue