forked from mindspore-Ecosystem/mindspore
!6282 [MSLITE][Develop] fix deconv fp16 bug
Merge pull request !6282 from ling/deconv
This commit is contained in:
commit
ea5e17ec7a
|
@ -82,8 +82,8 @@ int DeConvolutionFp16CPUKernel::InitParam() {
|
|||
matmul_param_->row_ = input_plane_;
|
||||
matmul_param_->deep_ = conv_param_->input_channel_;
|
||||
matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_;
|
||||
row16_ = UP_ROUND(matmul_param_->row_, 16);
|
||||
col8_ = UP_ROUND(matmul_param_->col_, 8);
|
||||
matmul_param_->row_16_ = UP_ROUND(matmul_param_->row_, C16NUM);
|
||||
matmul_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * kernel_plane_;
|
||||
|
||||
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM));
|
||||
thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM), thread_count_);
|
||||
|
@ -98,13 +98,15 @@ int DeConvolutionFp16CPUKernel::InitRunBuf() {
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
tmp_buffer_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(row16_ * col8_ * sizeof(float16_t)));
|
||||
tmp_buffer_ = reinterpret_cast<float16_t *>(
|
||||
ctx_->allocator->Malloc(matmul_param_->row_16_ * matmul_param_->col_8_ * sizeof(float16_t)));
|
||||
if (tmp_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "deconv Malloc tmp_buffer_ error!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
pack_input_ = reinterpret_cast<float16_t *>(malloc(row16_ * matmul_param_->deep_ * sizeof(float16_t)));
|
||||
pack_input_ =
|
||||
reinterpret_cast<float16_t *>(malloc(matmul_param_->row_16_ * matmul_param_->deep_ * sizeof(float16_t)));
|
||||
if (pack_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "deconv Malloc pack_input_ error!";
|
||||
return RET_ERROR;
|
||||
|
@ -147,7 +149,7 @@ int DeConvolutionFp16CPUKernel::DoDeconv(int task_id) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
auto tmp_buf = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * row16_;
|
||||
auto tmp_buf = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_16_;
|
||||
MatMulFp16(pack_input_, execute_weight_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
|
||||
tmp_buf, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_, oc * C8NUM * kernel_plane_, 0,
|
||||
false);
|
||||
|
|
|
@ -57,8 +57,6 @@ class DeConvolutionFp16CPUKernel : public ConvolutionBaseFP16CPUKernel {
|
|||
|
||||
private:
|
||||
MatMulParameter *matmul_param_;
|
||||
int row16_;
|
||||
int col8_;
|
||||
int input_plane_;
|
||||
int kernel_plane_;
|
||||
int output_plane_;
|
||||
|
|
Loading…
Reference in New Issue