forked from OSSInnovation/mindspore
!6780 [MSLITE] Fix the bug of fp16 fc tensor pack
Merge pull request !6780 from zhanyuan/dev
This commit is contained in:
commit
d9076f2ca2
|
@ -78,7 +78,15 @@ int FullconnectionFP16CPUKernel::ReSize() {
|
|||
}
|
||||
memset(b_pack_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(float16_t));
|
||||
|
||||
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()), b_pack_ptr_);
|
||||
fc_param_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
|
||||
if (fc_param_->b_const_) {
|
||||
if (in_tensors_[1]->data_type() == kNumberTypeFloat32) {
|
||||
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()), b_pack_ptr_);
|
||||
} else {
|
||||
InitMatrixB(reinterpret_cast<float16_t *>(in_tensors_[1]->data_c()), b_pack_ptr_);
|
||||
}
|
||||
}
|
||||
|
||||
if (in_tensors_.size() == 3) {
|
||||
bias_ptr_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(fc_param_->col_8_ * sizeof(float16_t)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
|
@ -108,6 +116,10 @@ void FullconnectionFP16CPUKernel::InitMatrixB(float *b_ptr, float16_t *b_pack_pt
|
|||
RowMajor2Col8MajorFp16(reinterpret_cast<void *>(b_ptr), b_pack_ptr, fc_param_->col_, fc_param_->deep_, true);
|
||||
}
|
||||
|
||||
void FullconnectionFP16CPUKernel::InitMatrixB(float16_t *b_ptr, float16_t *b_pack_ptr) {
|
||||
RowMajor2Col8MajorFp16(reinterpret_cast<void *>(b_ptr), b_pack_ptr, fc_param_->col_, fc_param_->deep_, false);
|
||||
}
|
||||
|
||||
int FullconnectionFP16CPUKernel::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
|
@ -156,6 +168,13 @@ int FullconnectionFP16CPUKernel::Run() {
|
|||
} else {
|
||||
InitMatrixA(reinterpret_cast<float16_t *>(in_tensors_[0]->data_c()), a_pack_ptr_);
|
||||
}
|
||||
if (!fc_param_->b_const_) {
|
||||
if (in_tensors_[1]->data_type() == kNumberTypeFloat32) {
|
||||
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()), b_pack_ptr_);
|
||||
} else {
|
||||
InitMatrixB(reinterpret_cast<float16_t *>(in_tensors_[1]->data_c()), b_pack_ptr_);
|
||||
}
|
||||
}
|
||||
ParallelLaunch(this->context_->thread_pool_, FcFP16Run, this, thread_count_);
|
||||
if (out_tensor->data_type() == kNumberTypeFloat32) {
|
||||
auto size = out_tensor->ElementsNum();
|
||||
|
|
|
@ -42,6 +42,7 @@ class FullconnectionFP16CPUKernel : public FullconnectionBaseCPUKernel {
|
|||
void InitMatrixA(float *a_ptr, float16_t *a_pack_ptr);
|
||||
void InitMatrixA(float16_t *a_ptr, float16_t *a_pack_ptr);
|
||||
void InitMatrixB(float *b_ptr, float16_t *b_pack_ptr);
|
||||
void InitMatrixB(float16_t *b_ptr, float16_t *b_pack_ptr);
|
||||
void FreeTmpBuffer();
|
||||
|
||||
private:
|
||||
|
|
Loading…
Reference in New Issue