!31299 fix dynamic resize bug for matmul fp16

Merge pull request !31299 from yeyunpeng2020/master
This commit is contained in:
i-robot 2022-03-16 02:57:00 +00:00 committed by Gitee
commit 0373d2f915
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 35 additions and 10 deletions

View File

@@ -33,8 +33,9 @@
// x9: writeMode // x9: writeMode
asm_function MatmulBaseFp16Neon asm_function MatmulBaseFp16Neon
sub sp, sp, #96 sub sp, sp, #160
st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64
stp x19, x20, [sp], #16 stp x19, x20, [sp], #16
stp x21, x22, [sp], #16 stp x21, x22, [sp], #16
@@ -950,8 +951,9 @@ LoopColEnd:
add x0, x0, x15 add x0, x0, x15
bgt LoopRowStart bgt LoopRowStart
sub sp, sp, #96 sub sp, sp, #160
ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64
ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64
ldp x19, x20, [sp], #16 ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16 ldp x21, x22, [sp], #16
ret ret

View File

@@ -76,20 +76,36 @@ void MatmulBaseFP16CPUKernel::InitParameter() {
} }
int MatmulBaseFP16CPUKernel::InitBias() { int MatmulBaseFP16CPUKernel::InitBias() {
if (params_->col_ != 0 && bias_ptr_ == nullptr) { int max_bias_data = 0;
int max_bias_data = UP_ROUND(params_->col_, C16NUM); if (params_->col_ == 0) {
if (in_tensors().size() == C3NUM) {
max_bias_data = in_tensors().at(THIRD_INPUT)->ElementsNum();
}
} else {
max_bias_data = UP_ROUND(params_->col_, C16NUM);
}
if (max_bias_data > bias_count_) {
auto bias_ptr_bak = bias_ptr_;
bias_ptr_ = reinterpret_cast<float16_t *>(malloc(max_bias_data * sizeof(float16_t))); bias_ptr_ = reinterpret_cast<float16_t *>(malloc(max_bias_data * sizeof(float16_t)));
if (bias_ptr_ == nullptr) { if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_ptr_ failed"; MS_LOG(ERROR) << "malloc bias_ptr_ failed";
return RET_ERROR; return RET_ERROR;
} }
if (in_tensors_.size() == 3) { if (bias_count_ == 0) {
auto bias_tensor = in_tensors_[2]; if (in_tensors_.size() == C3NUM) {
CHECK_NULL_RETURN(bias_tensor); auto bias_tensor = in_tensors_[THIRD_INPUT];
memcpy(bias_ptr_, bias_tensor->data(), bias_tensor->ElementsNum() * sizeof(float16_t)); CHECK_NULL_RETURN(bias_tensor);
memcpy(bias_ptr_, bias_tensor->data(), bias_tensor->ElementsNum() * sizeof(float16_t));
} else {
memset(bias_ptr_, 0, max_bias_data * sizeof(float16_t));
}
} else { } else {
memset(bias_ptr_, 0, max_bias_data * sizeof(float16_t)); memset(bias_ptr_, 0, max_bias_data * sizeof(float16_t));
memcpy(bias_ptr_, bias_ptr_bak, bias_count_ * sizeof(float16_t));
free(bias_ptr_bak);
bias_ptr_bak = nullptr;
} }
bias_count_ = max_bias_data;
} }
return RET_OK; return RET_OK;
} }
@@ -115,7 +131,11 @@ int MatmulBaseFP16CPUKernel::ReSize() {
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_, C8NUM)); thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_, C8NUM));
thread_stride_ = UP_DIV(UP_DIV(params_->col_, C8NUM), thread_count_) * C8NUM; thread_stride_ = UP_DIV(UP_DIV(params_->col_, C8NUM), thread_count_) * C8NUM;
} }
auto ret = InitBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "InitBias failed";
return RET_ERROR;
}
return RET_OK; return RET_OK;
} }
@@ -250,6 +270,9 @@ void MatmulBaseFP16CPUKernel::InitMatrixB(const void *src_ptr, TypeId src_data_t
int MatmulBaseFP16CPUKernel::Prepare() { int MatmulBaseFP16CPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), 2); CHECK_LESS_RETURN(in_tensors_.size(), 2);
CHECK_LESS_RETURN(out_tensors_.size(), 1); CHECK_LESS_RETURN(out_tensors_.size(), 1);
if (in_tensors_.size() == FOURTH_INPUT) {
MS_CHECK_TRUE_MSG(in_tensors_[THIRD_INPUT]->IsConst(), RET_ERROR, "matrix-c must be const when existing.");
}
ResizeParameter(); ResizeParameter();
if (params_->a_const_ == true) { if (params_->a_const_ == true) {
if (RET_OK != InitBufferA()) { if (RET_OK != InitBufferA()) {
@@ -330,7 +353,6 @@ int MatmulBaseFP16CPUKernel::Run() {
return RET_ERROR; return RET_ERROR;
} }
InitMatrixB(in_tensors_[1]->data(), in_tensors_[1]->data_type()); InitMatrixB(in_tensors_[1]->data(), in_tensors_[1]->data_type());
InitBias();
} }
CHECK_NULL_RETURN(c_ptr); CHECK_NULL_RETURN(c_ptr);

View File

@@ -67,6 +67,7 @@ class MatmulBaseFP16CPUKernel : public InnerKernel {
private: private:
int thread_stride_ = 0; int thread_stride_ = 0;
int thread_count_ = 0; int thread_count_ = 0;
int bias_count_ = 0;
bool vec_matmul_ = false; bool vec_matmul_ = false;
float16_t *a_pack_ptr_ = nullptr; float16_t *a_pack_ptr_ = nullptr;
float16_t *b_pack_ptr_ = nullptr; float16_t *b_pack_ptr_ = nullptr;