!18197 [MS][LITE][CPU] code check

Merge pull request !18197 from liuzhongkai/codecheck3
i-robot 2021-06-11 15:47:46 +08:00 committed by Gitee
commit 5b8478696e
29 changed files with 307 additions and 112 deletions

View File

@ -60,8 +60,8 @@ int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha
for (; i < ele_c8; i += C8NUM) {
float16x8_t src_tmp = vld1q_f16(src + i);
float16x8_t mul_tmp = vmulq_n_f16(src_tmp, alpha);
- float16x8_t mask = vcgtq_f16(src_tmp, vdupq_n_f16(0.0f));
- vst1q_f16(dst + i, vbslq_f32(mask, src_tmp, mul_tmp));
+ uint16x8_t mask = vcgtq_f16(src_tmp, vdupq_n_f16(0.0f));
+ vst1q_f16(dst + i, vbslq_f16(mask, src_tmp, mul_tmp));
}
#endif
for (; i < ele_num; ++i) {
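The fix above matters because the fp16 comparison intrinsics produce unsigned-integer lane masks: vcgtq_f16 returns uint16x8_t, and the matching select is vbslq_f16, not vbslq_f32. A minimal standalone sketch of the corrected loop (assuming an AArch64 toolchain with __ARM_FEATURE_FP16_VECTOR_ARITHMETIC; the function name is illustrative, not the library's):

#include <arm_neon.h>

// Sketch: keep src where src > 0, otherwise alpha * src, eight fp16 lanes at a time.
void LReluFp16Sketch(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) {
  int i = 0;
  for (; i <= ele_num - 8; i += 8) {
    float16x8_t src_tmp = vld1q_f16(src + i);
    float16x8_t mul_tmp = vmulq_n_f16(src_tmp, alpha);
    uint16x8_t mask = vcgtq_f16(src_tmp, vdupq_n_f16(0.0f));  // all-ones lanes where src > 0
    vst1q_f16(dst + i, vbslq_f16(mask, src_tmp, mul_tmp));    // bit-select per fp16 lane
  }
  for (; i < ele_num; ++i) {  // scalar tail
    dst[i] = src[i] > 0 ? src[i] : (float16_t)(src[i] * alpha);
  }
}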

View File

@ -97,13 +97,8 @@ int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size)
uint32x4_t mask = vmovq_n_u32(((uint32_t)(1u << 31) - 1));
uint32x4_t zeros = vdupq_n_u32(0);
for (; index <= size - 4; index += C4NUM) {
- #ifndef SUPPORT_NNIE
- uint32x4_t vin0 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(in0 + index)), mask);
- uint32x4_t vin1 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(in1 + index)), mask);
- #else
+ uint32x4_t vin0 = vandq_u32(vreinterpretq_u32_f32(vld1q_f32(in0 + index)), mask);
+ uint32x4_t vin1 = vandq_u32(vreinterpretq_u32_f32(vld1q_f32(in1 + index)), mask);
- #endif
float32x4_t vout = vbslq_f32(vceqq_u32(vandq_u32(vin0, vin1), zeros), vfalse, vtrue);
vst1q_f32(out + index, vout);
}

View File

@ -541,6 +541,8 @@ void GemmMatmul(int ta, int tb, int M, int N, int K, float alpha, const float *m
gcb.ca = 0;
gcb.cb = 0;
gcb.bias = NULL;
gcb.mat_a = NULL;
gcb.mat_b = NULL;
GemmMatmulPlus(ta, tb, M, N, K, alpha, mat_a, lda, mat_b, ldb, beta, mat_c, ldc, workspace, &gcb);
}
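The two new NULL stores finish initializing the stack-allocated callback struct before it is handed to GemmMatmulPlus, so no field is read indeterminate. A hedged sketch of the idea with a hypothetical mirror of the struct; value-initialization gives the same guarantee in one line:

struct GemmCbSketch {  // hypothetical stand-in for the real callback struct
  int ca;
  int cb;
  const float *bias;
  const float *mat_a;
  const float *mat_b;
};

void PrepareSketch() {
  GemmCbSketch gcb = {};  // zero/NULL every member at once
  // ...set only the fields the call actually needs, then pass &gcb along...
}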

View File

@ -56,10 +56,18 @@ int Convolution1x1FP16CPUKernel::InitConv1x1Param() {
if ((matmul_param_->row_ > (row_tile_ * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) {
multi_thread_by_hw_ = true;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, row_tile_));
if (thread_count_ <= 0) {
MS_LOG(ERROR) << "thread_count_ must be greater than 0!";
return RET_ERROR;
}
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, row_tile_), thread_count_) * row_tile_;
} else {
multi_thread_by_hw_ = false;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_tile_));
if (thread_count_ <= 0) {
MS_LOG(ERROR) << "thread_count_ must be greater than 0!";
return RET_ERROR;
}
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_tile_), thread_count_) * col_tile_;
}
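This hunk (and its fp32 twin later in the commit) guards the tile-based work split: when the row or column count is 0, UP_DIV yields 0 and the later division by thread_count_ would fault. A compact sketch of the same arithmetic, with hypothetical helpers standing in for MSMIN/UP_DIV:

#include <algorithm>

inline int UpDiv(int a, int b) { return (a + b - 1) / b; }  // ceiling division, b > 0

// Returns the per-thread stride in elements, or -1 when no valid split exists.
int PartitionByTile(int total, int tile, int thread_num, int *thread_count) {
  *thread_count = std::min(thread_num, UpDiv(total, tile));
  if (*thread_count <= 0) {  // total == 0 (or a bad thread_num) lands here
    return -1;
  }
  return UpDiv(UpDiv(total, tile), *thread_count) * tile;
}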
@ -212,6 +220,8 @@ static int Convolution1x1Fp16RunHw(void *cdata, int task_id, float lhs_scale, fl
int Convolution1x1FP16CPUKernel::Run() {
auto input_data = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
auto output_data = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
MS_ASSERT(input_data != nullptr);
MS_ASSERT(output_data != nullptr);
if (input_data == nullptr || output_data == nullptr) {
MS_LOG(ERROR) << "Convolution1x1 Fp16 get null tensor data!";
return RET_ERROR;

View File

@ -56,6 +56,7 @@ void *ConvolutionDelegateFP16CPUKernel::CopyData(lite::Tensor *tensor) {
MS_LOG(ERROR) << "Malloc copied_data failed.";
return nullptr;
}
MS_ASSERT(tensor->data_c() != nullptr);
memcpy(copied_data, tensor->data_c(), tensor->Size());
return copied_data;
}
@ -71,8 +72,10 @@ int ConvolutionDelegateFP16CPUKernel::Init() {
return RET_OK;
}
origin_weight_ = in_tensors_.at(kWeightIndex)->data_c();
MS_ASSERT(origin_weight_ != nullptr);
if (in_tensors_.size() == 3) {
origin_bias_ = in_tensors_.at(kBiasIndex)->data_c();
MS_ASSERT(origin_bias_ != nullptr);
}
return ReSize();
}
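The assertions added here back up the pattern CopyData relies on: allocate, verify the source tensor actually holds data, then memcpy. A free-standing sketch of that hardened sequence (hypothetical helper, not the kernel method):

#include <cstdlib>
#include <cstring>

void *CopyDataSketch(const void *src, size_t size) {
  if (src == nullptr || size == 0) {
    return nullptr;  // the source tensor has no backing data yet
  }
  void *copied = malloc(size);
  if (copied == nullptr) {
    return nullptr;  // allocation failed
  }
  memcpy(copied, src, size);  // safe: both pointers checked above
  return copied;
}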

View File

@ -36,6 +36,7 @@ int ConvolutionDepthwiseFp16CPUKernel::InitWeightBias() {
int channel = weight_tensor->Batch();
int pack_weight_size = channel * weight_tensor->Height() * weight_tensor->Width();
auto origin_weight = reinterpret_cast<float16_t *>(weight_tensor->data_c());
MS_ASSERT(origin_weight != nullptr);
packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
if (packed_weight_ == nullptr) {
@ -84,6 +85,8 @@ int ConvolutionDepthwiseFp16CPUKernel::ReSize() {
int ConvolutionDepthwiseFp16CPUKernel::Execute(int task_id) {
auto input_ptr = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
MS_ASSERT(input_ptr != nullptr);
MS_ASSERT(output_ptr != nullptr);
if (input_ptr == nullptr || output_ptr == nullptr) {
MS_LOG(ERROR) << "Convolution depthwise Fp16 get null tensor data!";
return RET_ERROR;

View File

@ -62,6 +62,7 @@ int ConvolutionDepthwiseSWFp16CPUKernel::InitWeightBias() {
int OC8 = UP_DIV(weight_tensor->Batch(), C8NUM);
int pack_weight_size = C8NUM * OC8 * weight_tensor->Height() * weight_tensor->Width();
auto origin_weight = reinterpret_cast<float16_t *>(weight_tensor->data_c());
MS_ASSERT(origin_weight != nullptr);
packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
if (packed_weight_ == nullptr) {
@ -141,6 +142,8 @@ int ConvolutionDepthwiseSWFp16CPUKernel::Run() {
auto input_ptr = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
MS_ASSERT(input_ptr != nullptr);
MS_ASSERT(output_ptr != nullptr);
if (input_ptr == nullptr || output_ptr == nullptr) {
MS_LOG(ERROR) << "Convolution depthwise Fp16 get null tensor data!";
return RET_ERROR;

View File

@ -116,6 +116,8 @@ int ConvolutionFP16CPUKernel::ReSize() {
int ConvolutionFP16CPUKernel::RunImpl(int task_id) {
auto input_ptr = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
MS_ASSERT(input_ptr != nullptr);
MS_ASSERT(output_ptr != nullptr);
if (input_ptr == nullptr || output_ptr == nullptr) {
MS_LOG(ERROR) << "Convolution Fp16 get null tensor data!";
return RET_ERROR;

View File

@ -155,11 +155,16 @@ int ConvolutionWinogradFP16CPUKernel::Init() {
return RET_OK;
}
- void ConvolutionWinogradFP16CPUKernel::AdjustNumberOfThread() {
+ int ConvolutionWinogradFP16CPUKernel::AdjustNumberOfThread() {
auto out_tensor = out_tensors_.front();
int cal_plane = UP_DIV(out_tensor->Height(), output_unit_) * UP_DIV(out_tensor->Width(), output_unit_);
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(cal_plane, C8NUM));
if (thread_count_ <= 0) {
MS_LOG(ERROR) << "thread_count_ must be greater than 0!";
return RET_ERROR;
}
conv_param_->thread_num_ = thread_count_;
return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::ReSize() {
@ -171,20 +176,26 @@ int ConvolutionWinogradFP16CPUKernel::ReSize() {
ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.";
- return RET_ERROR;
+ return ret;
}
ret = ConfigInputOutput();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConfigInputOutput failed.";
- return RET_ERROR;
+ return ret;
}
+ ret = AdjustNumberOfThread();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "AdjustNumberOfThread failed.";
+   return ret;
+ }
- AdjustNumberOfThread();
return RET_OK;
}
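The rewritten ReSize shows the commit's error-propagation rule: each stage logs its own failure and returns the stage's actual status code instead of flattening everything to RET_ERROR. A sketch of that shape with hypothetical stub stages:

namespace sketch {
int BaseInit() { return 0; }               // stand-ins for the real stages
int ConfigInputOutput() { return 0; }
int AdjustNumberOfThread() { return 0; }

int ReSize() {
  int ret = BaseInit();
  if (ret != 0) return ret;                // forward the stage's own code
  ret = ConfigInputOutput();
  if (ret != 0) return ret;
  ret = AdjustNumberOfThread();            // now reports a status instead of void
  if (ret != 0) return ret;
  return 0;
}
}  // namespace sketch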
int ConvolutionWinogradFP16CPUKernel::RunImpl(int task_id) {
auto input_ptr = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
MS_ASSERT(input_ptr != nullptr);
MS_ASSERT(output_ptr != nullptr);
if (input_ptr == nullptr || output_ptr == nullptr) {
MS_LOG(ERROR) << "Convolution Winograd Fp16 get null tensor data!";
return RET_ERROR;

View File

@ -51,7 +51,7 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseCPUKernel {
int InitTmpBuffer();
int ConfigInputOutput();
int WinogradFilterTransformFp16(const float16_t *weight_data, float *matrix_g, float *matrix_gt, int oc_block);
- void AdjustNumberOfThread();
+ int AdjustNumberOfThread();
private:
void FreeTmpBuffer() {
@ -72,8 +72,8 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseCPUKernel {
col_buffer_ = nullptr;
}
}
- int kernel_unit_;
- int input_unit_;
+ int kernel_unit_ = 0;
+ int input_unit_ = 0;
int output_unit_;
void *origin_weight_; // do not free
void *origin_bias_; // do not free
@ -83,10 +83,10 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseCPUKernel {
float16_t *trans_weight_ = nullptr;
float16_t *col_buffer_ = nullptr;
TmpBufferAddressFp16 tmp_buffer_address_list_[4];
- InputTransFp16Func in_func_;
- OutputTransFp16Func out_func_;
- int col_tile_;
- int row_tile_;
+ InputTransFp16Func in_func_ = nullptr;
+ OutputTransFp16Func out_func_ = nullptr;
+ int col_tile_ = 0;
+ int row_tile_ = 0;
};
} // namespace mindspore::kernel
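The header changes replace constructor-era assumptions with in-class default member initializers, so a kernel whose Init/ReSize bails out early can never expose indeterminate ints or dangling function pointers. A minimal sketch of the pattern (the member types are illustrative):

class KernelStateSketch {
 private:
  int kernel_unit_ = 0;                                   // was: int kernel_unit_;
  int input_unit_ = 0;
  void (*in_func_)(const float *, float *) = nullptr;     // hypothetical transform signature
  void (*out_func_)(const float *, float *) = nullptr;
  int col_tile_ = 0;
  int row_tile_ = 0;
};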

View File

@ -73,6 +73,7 @@ int DeconvolutionDepthwiseFp16CPUKernel::InitWeightBias() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
int OC8 = UP_DIV(weight_tensor->Batch(), C8NUM);
auto origin_weight = reinterpret_cast<float16_t *>(weight_tensor->data_c());
MS_ASSERT(origin_weight != nullptr);
int pack_weight_size = C8NUM * OC8 * weight_tensor->Height() * weight_tensor->Width();
packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
@ -118,8 +119,12 @@ int DeconvolutionDepthwiseFp16CPUKernel::Init() {
}
int DeconvolutionDepthwiseFp16CPUKernel::ReSize() {
- InitSlideParam();
- auto ret = ConvolutionBaseCPUKernel::Init();
+ auto ret = InitSlideParam();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "InitSlideParam failed!";
+   return ret;
+ }
+ ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
return ret;
}
@ -156,6 +161,8 @@ int DeconvolutionDepthwiseFp16CPUKernel::Run() {
auto input_ptr = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
MS_ASSERT(input_ptr != nullptr);
MS_ASSERT(output_ptr != nullptr);
if (input_ptr == nullptr || output_ptr == nullptr) {
MS_LOG(ERROR) << "Deconvolution depthwise Fp16 get null tensor data!";
return RET_ERROR;

View File

@ -39,8 +39,11 @@ DeConvolutionFp16CPUKernel::~DeConvolutionFp16CPUKernel() {
}
int DeConvolutionFp16CPUKernel::ReSize() {
- ConvolutionBaseCPUKernel::Init();
+ auto ret = ConvolutionBaseCPUKernel::Init();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "ConvolutionBaseCPUKernel Init error!";
+   return ret;
+ }
int error_code = InitParam();
if (error_code != RET_OK) {
MS_LOG(ERROR) << "deconv InitParam error!";
@ -70,7 +73,7 @@ int DeConvolutionFp16CPUKernel::InitWeightBias() {
}
if (in_tensors_.at(kBiasIndex)->shape().size() == 1 &&
in_tensors_.at(kBiasIndex)->DimensionSize(0) == output_channel) {
- memcpy(bias_data_, in_tensors_.at(2)->data_c(), output_channel * sizeof(float16_t));
+ memcpy(bias_data_, in_tensors_.at(kBiasIndex)->data_c(), output_channel * sizeof(float16_t));
} else {
MS_LOG(ERROR) << "unsupported bias shape for deconv!";
return RET_ERROR;
@ -88,8 +91,8 @@ int DeConvolutionFp16CPUKernel::InitWeightBias() {
MS_LOG(ERROR) << "deconv fp16 kernel require fp16 weight";
return RET_ERROR;
}
- PackNHWCFp16ToC8HWN8Fp16(reinterpret_cast<float16_t *>(in_tensors_.at(1)->data_c()), pack_weight_, input_channel,
-                          kernel_w * kernel_h, output_channel);
+ PackNHWCFp16ToC8HWN8Fp16(reinterpret_cast<float16_t *>(in_tensors_.at(kWeightIndex)->data_c()), pack_weight_,
+                          input_channel, kernel_w * kernel_h, output_channel);
return RET_OK;
}
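Both deconv kernels in this commit validate the bias the same way before copying it: the tensor must be 1-D with exactly output_channel elements, and (per the new asserts elsewhere) non-null. A standalone sketch of that check, with a hypothetical helper in place of the tensor API:

#include <cstring>
#include <vector>

int CopyBiasSketch(const std::vector<int> &bias_shape, const void *bias_data,
                   void *bias_buf, int output_channel, size_t elem_size) {
  if (bias_shape.size() != 1 || bias_shape[0] != output_channel) {
    return -1;  // unsupported bias shape for deconv
  }
  if (bias_data == nullptr || bias_buf == nullptr) {
    return -1;  // null tensor data
  }
  memcpy(bias_buf, bias_data, (size_t)output_channel * elem_size);
  return 0;
}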
@ -199,6 +202,8 @@ int DeConvolutionFp16CPUKernel::Init() {
int DeConvolutionFp16CPUKernel::Run() {
auto input_ptr = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
MS_ASSERT(input_ptr != nullptr);
MS_ASSERT(output_ptr != nullptr);
if (input_ptr == nullptr || output_ptr == nullptr) {
MS_LOG(ERROR) << "DeConvolution Fp16 get null tensor data!";
return RET_ERROR;

View File

@ -315,7 +315,7 @@ int DeConvWinogradFp16CPUKernel::InitDataParam() {
/* unit data : weight & winograd data */
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto origin_weight = reinterpret_cast<float16_t *>(weight_tensor->data_c());
MS_ASSERT(origin_weight != nullptr);
for (int i = 0; i < deconv_param_->compute_size_; i++) {
DeConvComputeUnit *unit = &deconv_param_->compute_units_[i];
auto ret = PackDeConvWgDataFp16(origin_weight, unit, conv_param_, deconv_param_);
@ -341,8 +341,16 @@ int DeConvWinogradFp16CPUKernel::InitDataParam() {
int DeConvWinogradFp16CPUKernel::ReSize() {
FreeResizeBuf();
- ConvolutionBaseCPUKernel::Init();
- InitParameter();
+ auto ret = ConvolutionBaseCPUKernel::Init();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "ConvolutionBaseCPUKernel init failed!";
+   return ret;
+ }
+ ret = InitParameter();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "InitParameter failed!";
+   return ret;
+ }
return RET_OK;
}
@ -379,6 +387,8 @@ int DeConvWinogradFp16CPUKernel::Init() {
int DeConvWinogradFp16CPUKernel::Run() {
auto input_ptr = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
MS_ASSERT(input_ptr != nullptr);
MS_ASSERT(output_ptr != nullptr);
if (input_ptr == nullptr || output_ptr == nullptr) {
MS_LOG(ERROR) << "Deconvolution Winograd Fp16 get null tensor data!";
return RET_ERROR;
@ -389,12 +399,19 @@ int DeConvWinogradFp16CPUKernel::Run() {
nhwc_output_ = output_ptr + batch_index * deconv_param_->output_plane_ * conv_param_->output_channel_;
::memset(nc4hw4_output_, 0, deconv_param_->output_plane_ * deconv_param_->oc_div4_ * C4NUM * sizeof(float16_t));
- static_cast<const lite::InnerContext *>(this->context_)
-   ->thread_pool_->ParallelLaunch(DeConvWgFp16Run, this, deconv_param_->thread_num_);
+ auto ret = static_cast<const lite::InnerContext *>(this->context_)
+              ->thread_pool_->ParallelLaunch(DeConvWgFp16Run, this, deconv_param_->thread_num_);
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "DeConvWgFp16Run failed!";
+   return ret;
+ }
// post bias activate and nhwc
- static_cast<const lite::InnerContext *>(this->context_)
-   ->thread_pool_->ParallelLaunch(DeConvWgPostFp16Run, this, thread_num_hw_);
+ ret = static_cast<const lite::InnerContext *>(this->context_)
+         ->thread_pool_->ParallelLaunch(DeConvWgPostFp16Run, this, thread_num_hw_);
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "DeConvWgPostFp16Run failed!";
+   return ret;
+ }
}
return RET_OK;
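The Run changes follow one discipline: a ParallelLaunch result is never discarded; the first failing phase is logged and returned. A generic sketch of that call-and-check shape, with a hypothetical sequential launcher standing in for the lite thread pool:

#include <functional>

int LaunchChecked(const std::function<int(int)> &task, int thread_num) {
  for (int t = 0; t < thread_num; ++t) {  // stand-in for the pool's fan-out
    int ret = task(t);
    if (ret != 0) {
      return ret;  // surface the first failing task instead of ignoring it
    }
  }
  return 0;
}

// Usage shape mirroring Run(): two phases, each checked before continuing.
int RunSketch() {
  int ret = LaunchChecked([](int) { return 0; /* compute phase */ }, 4);
  if (ret != 0) return ret;
  return LaunchChecked([](int) { return 0; /* post/bias phase */ }, 4);
}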

View File

@ -47,15 +47,15 @@ class DeConvWinogradFp16CPUKernel : public ConvolutionBaseCPUKernel {
void FreeResizeBuf();
private:
- DeConvParam *deconv_param_;
+ DeConvParam *deconv_param_ = nullptr;
std::mutex lock_;
float16_t *nhwc_input_ = nullptr;
float16_t *nhwc_output_ = nullptr;
float16_t *nc4hw4_output_ = nullptr;
float16_t *tile_input_ = nullptr;
float16_t *tile_output_ = nullptr;
- int thread_num_hw_;
- int thread_stride_hw_;
+ int thread_num_hw_ = 0;
+ int thread_stride_hw_ = 0;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_WINOGRAD_H_

View File

@ -27,6 +27,7 @@ int GroupConvolutionFP16CPUKernel::SeparateInput(int group_id) {
int sub_in_channel = conv_param_->input_channel_;
int ori_in_channel = sub_in_channel * group_num_;
auto sub_in_data = static_cast<lite::Tensor *>(group_convs_.at(group_id)->in_tensors().front())->data_c();
MS_ASSERT(sub_in_data != nullptr);
auto in_data_type = in_tensors_.front()->data_type();
auto sub_in_data_type = group_convs_.at(group_id)->in_tensors().front()->data_type();
if (in_data_type != sub_in_data_type) {

View File

@ -63,7 +63,6 @@ void MatmulBaseFP16CPUKernel::FreeResizeBufB() {
void MatmulBaseFP16CPUKernel::InitParameter() {
params_->a_const_ = (in_tensors_[0]->data_c() != nullptr);
params_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
- return;
}
int MatmulBaseFP16CPUKernel::InitBias() {
@ -80,6 +79,7 @@ int MatmulBaseFP16CPUKernel::InitBias() {
MS_LOG(ERROR) << "Matmul fp16 only support fp16 weight";
return RET_ERROR;
} else if (bias_tensor->data_type() == kNumberTypeFloat16) {
MS_ASSERT(bias_tensor->data_c() != nullptr);
memcpy(bias_ptr_, bias_tensor->data_c(), bias_tensor->ElementsNum() * sizeof(float16_t));
} else {
MS_LOG(ERROR) << "Unsupported bias data type : " << bias_tensor->data_type();
@ -116,7 +116,6 @@ void MatmulBaseFP16CPUKernel::ResizeParameter() {
params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
params_->col_align_ = UP_ROUND(params_->col_, C8NUM);
}
- return;
}
int MatmulBaseFP16CPUKernel::InitBufferA() {

View File

@ -106,10 +106,18 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
if ((matmul_param_->row_ > (row_tile_ * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) {
multi_thread_by_hw_ = true;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, row_tile_));
if (thread_count_ <= 0) {
MS_LOG(ERROR) << "thread_count_ must be greater than 0!";
return RET_ERROR;
}
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, row_tile_), thread_count_) * row_tile_;
} else {
multi_thread_by_hw_ = false;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_tile_));
if (thread_count_ <= 0) {
MS_LOG(ERROR) << "thread_count_ must be greater than 0!";
return RET_ERROR;
}
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_tile_), thread_count_) * col_tile_;
}
@ -223,8 +231,10 @@ int Convolution1x1RunHw(void *cdata, int task_id, float lhs_scale, float rhs_sca
}
int Convolution1x1CPUKernel::Run() {
- auto src_in = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
- auto src_out = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
+ auto src_in = reinterpret_cast<float *>(in_tensors_[0]->data_c());
+ auto src_out = reinterpret_cast<float *>(out_tensors_[0]->data_c());
+ MS_ASSERT(src_in != nullptr);
+ MS_ASSERT(src_out != nullptr);
int pack_input_size = multi_thread_by_hw_ ? (thread_count_ * row_tile_ * matmul_param_->deep_)
: (matmul_param_->row_align_ * matmul_param_->deep_);
pack_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(pack_input_size * sizeof(float)));
@ -270,20 +280,22 @@ void Convolution1x1CPUKernel::PackWeight() {
int size = input_channel * UP_ROUND(output_channel, col_tile_) * sizeof(float);
int down_size = input_channel * DOWN_DIV(output_channel, col_tile_) * col_tile_ * sizeof(float);
memset(reinterpret_cast<char *>(weight_ptr_) + down_size, 0, size - down_size);
MS_ASSERT(filter_tensor->data_c() != nullptr);
#ifdef ENABLE_AVX
- RowMajor2Col16Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
-                     input_channel);
+ RowMajor2Col16Major(reinterpret_cast<float *>(filter_tensor->data_c()), weight_ptr_, output_channel, input_channel);
#elif defined(ENABLE_ARM32)
- RowMajor2Col4Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
-                    input_channel);
+ RowMajor2Col4Major(reinterpret_cast<float *>(filter_tensor->data_c()), weight_ptr_, output_channel, input_channel);
#else
- RowMajor2Col8Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
-                    input_channel);
+ RowMajor2Col8Major(reinterpret_cast<float *>(filter_tensor->data_c()), weight_ptr_, output_channel, input_channel);
#endif
}
int Convolution1x1CPUKernel::Eval() {
- InnerKernel::Eval();
+ auto ret = InnerKernel::Eval();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "eval failed!";
+   return ret;
+ }
if (is_trainable()) {
PackWeight();
}
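The MutableData()-to-data_c() swaps across these fp32 kernels rest on an assumed distinction: MutableData() may allocate a buffer on demand, while data_c() simply returns whatever pointer already exists, possibly nullptr; hence every converted call site also gains an MS_ASSERT or explicit check. A hypothetical illustration of those semantics (not the lite::Tensor API):

#include <cstdlib>

class TensorSketch {
 public:
  ~TensorSketch() { free(data_); }
  void *MutableData() {
    if (data_ == nullptr) data_ = malloc(size_);  // lazy allocation side effect
    return data_;
  }
  void *data_c() const { return data_; }  // plain accessor: may return nullptr
 private:
  void *data_ = nullptr;
  size_t size_ = 64;
};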

View File

@ -43,6 +43,7 @@ float *ConvolutionDelegateCPUKernel::CopyData(lite::Tensor *tensor) {
MS_LOG(ERROR) << "Malloc data failed.";
return nullptr;
}
MS_ASSERT(tensor->data_c() != nullptr);
memcpy(data, tensor->data_c(), tensor->Size());
return data;
}
@ -64,6 +65,7 @@ int ConvolutionDelegateCPUKernel::GetWeightAndBias() {
int ConvolutionDelegateCPUKernel::GetWeightData() {
if (InferShapeDone()) {
origin_weight_ = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->data_c());
MS_ASSERT(origin_weight_ != nullptr);
return RET_OK;
}
origin_weight_ = CopyData(in_tensors_.at(kWeightIndex));
@ -79,6 +81,7 @@ int ConvolutionDelegateCPUKernel::GetBiasData() {
if (in_tensors_.size() == 3) {
if (InferShapeDone()) {
origin_bias_ = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->data_c());
MS_ASSERT(origin_bias_ != nullptr);
return RET_OK;
} else {
origin_bias_ = CopyData(in_tensors_.at(kBiasIndex));

View File

@ -34,7 +34,7 @@ ConvolutionDepthwise3x3CPUKernel::~ConvolutionDepthwise3x3CPUKernel() {
int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() {
// init weight: k, h, w, c; k == group == output_channel, c == 1
auto weight_tensor = in_tensors_[kWeightIndex];
- auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
+ auto origin_weight = reinterpret_cast<float *>(weight_tensor->data_c());
int channel = weight_tensor->Batch();
int c4 = UP_ROUND(channel, C4NUM);
int pack_weight_size = c4 * C12NUM;
@ -58,7 +58,7 @@ int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() {
memset(bias_data_, 0, c4 * sizeof(float));
if (in_tensors_.size() == kInputSize2) {
auto bias_tensor = in_tensors_[kBiasIndex];
- auto ori_bias = reinterpret_cast<float *>(bias_tensor->MutableData());
+ auto ori_bias = reinterpret_cast<float *>(bias_tensor->data_c());
memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(float));
}
@ -67,7 +67,7 @@ int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() {
int ConvolutionDepthwise3x3CPUKernel::Init() {
auto ret = InitWeightBias();
- if (ret != 0) {
+ if (ret != RET_OK) {
MS_LOG(ERROR) << "Convolution depthwise 3x3 fp32 InitWeightBias failed.";
return RET_ERROR;
}
@ -78,7 +78,11 @@ int ConvolutionDepthwise3x3CPUKernel::Init() {
}
int ConvolutionDepthwise3x3CPUKernel::ReSize() {
- ConvolutionBaseCPUKernel::Init();
+ auto ret = ConvolutionBaseCPUKernel::Init();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "ConvolutionBaseCPUKernel::Init() return is:" << ret;
+   return ret;
+ }
conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
return RET_OK;
}
@ -116,15 +120,19 @@ int ConvolutionDepthwise3x3CPUKernel::Run() {
}
if (IsTrain() && is_trainable()) {
- InitWeightBias();
+ if (InitWeightBias() != RET_OK) {
+   ctx_->allocator->Free(buffer_);
+   MS_LOG(ERROR) << "Convolution depthwise 3x3 run InitWeightBias failed.";
+   return RET_ERROR;
+ }
}
auto input_tensor = in_tensors_.at(kInputIndex);
input_ptr_ = reinterpret_cast<float *>(input_tensor->data_c());
MS_ASSERT(input_ptr_ != nullptr);
auto output_tensor = out_tensors_.at(kOutputIndex);
output_ptr_ = reinterpret_cast<float *>(output_tensor->data_c());
MS_ASSERT(output_ptr_ != nullptr);
auto ret = static_cast<const lite::InnerContext *>(this->context_)
->thread_pool_->ParallelLaunch(ConvDw3x3Run, this, conv_param_->thread_num_);
ctx_->allocator->Free(buffer_);
@ -136,9 +144,16 @@ int ConvolutionDepthwise3x3CPUKernel::Run() {
}
int ConvolutionDepthwise3x3CPUKernel::Eval() {
- InnerKernel::Eval();
+ auto ret = InnerKernel::Eval();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "eval failed!";
+   return ret;
+ }
if (is_trainable()) {
- InitWeightBias();
+ if (InitWeightBias() != RET_OK) {
+   MS_LOG(ERROR) << "Convolution depthwise 3x3 fp32 Eval:InitWeightBias failed.";
+   return RET_ERROR;
+ }
}
return RET_OK;
}
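Eval overrides across the commit converge on one shape: propagate the base-class status instead of discarding it, then re-initialize packed weights only for trainable kernels, failing loudly if that re-pack fails. A sketch with hypothetical base class and hooks:

struct InnerKernelSketch {
  virtual ~InnerKernelSketch() = default;
  virtual int Eval() { return 0; }
};

struct ConvKernelSketch : InnerKernelSketch {
  bool trainable = false;
  int InitWeightBias() { return 0; }   // hypothetical weight re-pack
  int Eval() override {
    int ret = InnerKernelSketch::Eval();
    if (ret != 0) return ret;          // base failure no longer swallowed
    if (trainable && InitWeightBias() != 0) return -1;
    return 0;
  }
};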

View File

@ -32,7 +32,8 @@ ConvolutionDepthwiseCPUKernel::~ConvolutionDepthwiseCPUKernel() {
int ConvolutionDepthwiseCPUKernel::InitWeightBias() {
// init weight: k, h, w, c; k == group == output_channel, c == 1
auto weight_tensor = in_tensors_.at(kWeightIndex);
- auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
+ auto origin_weight = reinterpret_cast<float *>(weight_tensor->data_c());
+ MS_ASSERT(origin_weight != nullptr);
int channel = weight_tensor->Batch();
int pack_weight_size = weight_tensor->Batch() * weight_tensor->Height() * weight_tensor->Width();
if (pack_weight_size >= std::numeric_limits<int>::max() / static_cast<int>(sizeof(float))) {
@ -55,7 +56,7 @@ int ConvolutionDepthwiseCPUKernel::InitWeightBias() {
memset(bias_data_, 0, channel * sizeof(float));
if (in_tensors_.size() == kInputSize2) {
auto bias_tensor = in_tensors_[kBiasIndex];
- auto ori_bias = reinterpret_cast<float *>(bias_tensor->MutableData());
+ auto ori_bias = reinterpret_cast<float *>(bias_tensor->data_c());
memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(float));
}
@ -64,7 +65,7 @@ int ConvolutionDepthwiseCPUKernel::InitWeightBias() {
int ConvolutionDepthwiseCPUKernel::Init() {
auto ret = InitWeightBias();
- if (ret != 0) {
+ if (ret != RET_OK) {
MS_LOG(ERROR) << "Convolution depthwise fp32 InitWeightBias failed.";
return RET_ERROR;
}
@ -75,8 +76,16 @@ int ConvolutionDepthwiseCPUKernel::Init() {
}
int ConvolutionDepthwiseCPUKernel::ReSize() {
- ConvolutionBaseCPUKernel::Init();
+ auto ret = ConvolutionBaseCPUKernel::Init();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "ConvolutionBaseCPUKernel::Init() return is:" << ret;
+   return ret;
+ }
conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
if (conv_param_->thread_num_ <= 0) {
MS_LOG(ERROR) << "conv_param_->thread_num_ must be greater than 0!";
return RET_ERROR;
}
return RET_OK;
}
@ -101,10 +110,11 @@ int ConvolutionDepthwiseCPUKernel::Run() {
}
auto input_tensor = in_tensors_.at(kInputIndex);
- input_ptr_ = reinterpret_cast<float *>(input_tensor->MutableData());
+ input_ptr_ = reinterpret_cast<float *>(input_tensor->data_c());
+ MS_ASSERT(input_ptr_ != nullptr);
auto output_tensor = out_tensors_.at(kOutputIndex);
- output_ptr_ = reinterpret_cast<float *>(output_tensor->MutableData());
+ output_ptr_ = reinterpret_cast<float *>(output_tensor->data_c());
+ MS_ASSERT(output_ptr_ != nullptr);
auto ret = static_cast<const lite::InnerContext *>(this->context_)
->thread_pool_->ParallelLaunch(ConvDwRun, this, conv_param_->thread_num_);
@ -117,13 +127,19 @@ int ConvolutionDepthwiseCPUKernel::Run() {
void ConvolutionDepthwiseCPUKernel::PackWeight() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
- auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
+ auto origin_weight = reinterpret_cast<float *>(weight_tensor->data_c());
+ MS_ASSERT(origin_weight != nullptr);
PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
weight_tensor->Batch());
}
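For orientation, PackWeightKHWToHWKFp32 is assumed here to transpose the depthwise weight from [k][plane] to [plane][k] layout so the per-pixel kernels become contiguous; a hypothetical minimal version of that repack:

// src: channel-major [c][plane]; dst: plane-major [p][channel].
void PackKHWToHWKSketch(const float *src, float *dst, int plane, int channel) {
  for (int c = 0; c < channel; ++c) {
    for (int p = 0; p < plane; ++p) {
      dst[p * channel + c] = src[c * plane + p];
    }
  }
}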
int ConvolutionDepthwiseCPUKernel::Eval() {
- InnerKernel::Eval();
+ auto ret = InnerKernel::Eval();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "eval failed!";
+   return ret;
+ }
if (is_trainable()) {
PackWeight();
}

View File

@ -40,7 +40,8 @@ ConvolutionDepthwiseIndirectCPUKernel::~ConvolutionDepthwiseIndirectCPUKernel()
int ConvolutionDepthwiseIndirectCPUKernel::InitWeightBias() {
// init weight: o, h, w, i; o == group, i == 1
auto weight_tensor = in_tensors_[kWeightIndex];
- auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
+ auto origin_weight = reinterpret_cast<float *>(weight_tensor->data_c());
+ MS_ASSERT(origin_weight != nullptr);
#ifdef ENABLE_AVX
int div_flag = C8NUM;
#else
@ -70,7 +71,7 @@ int ConvolutionDepthwiseIndirectCPUKernel::InitWeightBias() {
if (in_tensors_.size() == kInputSize2) {
auto bias_tensor = in_tensors_[kBiasIndex];
- auto ori_bias = reinterpret_cast<float *>(bias_tensor->MutableData());
+ auto ori_bias = reinterpret_cast<float *>(bias_tensor->data_c());
memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(float));
} else {
memset(bias_data_, 0, batch_flag * div_flag * sizeof(float));
@ -117,13 +118,21 @@ int ConvolutionDepthwiseIndirectCPUKernel::ReSize() {
free(indirect_buffer_);
indirect_buffer_ = nullptr;
}
- ConvolutionBaseCPUKernel::Init();
- auto ret = MallocIndirectBuffer();
+ auto ret = ConvolutionBaseCPUKernel::Init();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "ConvolutionBaseCPUKernel::Init() return is:" << ret;
+   return ret;
+ }
+ ret = MallocIndirectBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionDepthwiseIndirect MallocIndirectBuffer failed";
return RET_ERROR;
}
conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
if (conv_param_->thread_num_ <= 0) {
MS_LOG(ERROR) << "conv_param_->thread_num_ must be greater than 0!";
return RET_ERROR;
}
return RET_OK;
}
@ -162,6 +171,7 @@ int ConvolutionDepthwiseIndirectCPUKernel::MallocPackedInput() {
int ConvolutionDepthwiseIndirectCPUKernel::Run() {
auto input_tensor = in_tensors_.at(kInputIndex);
auto input_ptr = reinterpret_cast<float *>(input_tensor->data_c());
MS_ASSERT(input_ptr != nullptr);
#ifdef ENABLE_AVX
int div_flag = C8NUM;
#else
@ -190,7 +200,7 @@ int ConvolutionDepthwiseIndirectCPUKernel::Run() {
auto output_tensor = out_tensors_.at(kOutputIndex);
output_ptr_ = reinterpret_cast<float *>(output_tensor->data_c());
MS_ASSERT(output_ptr_ != nullptr);
ConvDwInitIndirection(indirect_buffer_, packed_input_, zero_ptr_, conv_param_, step_h, step_w);
auto ret = static_cast<const lite::InnerContext *>(this->context_)
@ -207,7 +217,8 @@ int ConvolutionDepthwiseIndirectCPUKernel::Run() {
void ConvolutionDepthwiseIndirectCPUKernel::PackWeight() {
auto weight_tensor = in_tensors_[kWeightIndex];
- auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
+ auto origin_weight = reinterpret_cast<float *>(weight_tensor->data_c());
+ MS_ASSERT(origin_weight != nullptr);
#ifdef ENABLE_AVX
PackDepthwiseIndirectWeightC8Fp32(origin_weight, packed_weight_, weight_tensor->Height(), weight_tensor->Width(),
weight_tensor->Batch());
@ -218,7 +229,11 @@ void ConvolutionDepthwiseIndirectCPUKernel::PackWeight() {
}
int ConvolutionDepthwiseIndirectCPUKernel::Eval() {
- InnerKernel::Eval();
+ auto ret = InnerKernel::Eval();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "eval failed!";
+   return ret;
+ }
if (is_trainable()) {
PackWeight();
}

View File

@ -36,7 +36,8 @@ ConvolutionDepthwiseSWCPUKernel::~ConvolutionDepthwiseSWCPUKernel() {
int ConvolutionDepthwiseSWCPUKernel::InitWeightBias() {
// init weight: o, h, w, i; o == group, i == 1
auto weight_tensor = in_tensors_.at(kWeightIndex);
- auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
+ auto origin_weight = reinterpret_cast<float *>(weight_tensor->data_c());
+ MS_ASSERT(origin_weight != nullptr);
int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
int pack_weight_size = C4NUM * OC4 * weight_tensor->Height() * weight_tensor->Width();
@ -62,7 +63,7 @@ int ConvolutionDepthwiseSWCPUKernel::InitWeightBias() {
memset(bias_data_, 0, malloc_size * sizeof(float));
if (in_tensors_.size() == kInputSize2) {
auto bias_tensor = in_tensors_.at(kBiasIndex);
- auto ori_bias = reinterpret_cast<float *>(bias_tensor->MutableData());
+ auto ori_bias = reinterpret_cast<float *>(bias_tensor->data_c());
memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(float));
}
@ -111,9 +112,17 @@ int ConvolutionDepthwiseSWCPUKernel::Init() {
}
int ConvolutionDepthwiseSWCPUKernel::ReSize() {
- ConvolutionBaseCPUKernel::Init();
+ auto ret = ConvolutionBaseCPUKernel::Init();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "ConvolutionBaseCPUKernel::Init() return is:" << ret;
+   return ret;
+ }
InitSlidingParamConvDw(sliding_, conv_param_, C4NUM);
conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
if (conv_param_->thread_num_ <= 0) {
MS_LOG(ERROR) << "conv_param_->thread_num_ must be greater than 0!";
return RET_ERROR;
}
return RET_OK;
}
@ -146,8 +155,8 @@ int ConvolutionDepthwiseSWCPUKernel::Run() {
}
auto input_tensor = in_tensors_.at(kInputIndex);
- auto input_ptr = reinterpret_cast<float *>(input_tensor->MutableData());
+ auto input_ptr = reinterpret_cast<float *>(input_tensor->data_c());
+ MS_ASSERT(input_ptr != nullptr);
if (need_align_) {
PackNHWCToNHWC4Fp32(input_ptr, packed_input_, conv_param_->input_batch_,
conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
@ -156,8 +165,8 @@ int ConvolutionDepthwiseSWCPUKernel::Run() {
}
auto output_tensor = out_tensors_.at(kOutputIndex);
- auto output_ptr = reinterpret_cast<float *>(output_tensor->MutableData());
+ auto output_ptr = reinterpret_cast<float *>(output_tensor->data_c());
+ MS_ASSERT(output_ptr != nullptr);
if (!need_align_) {
packed_output_ = output_ptr;
}
@ -187,13 +196,18 @@ void ConvolutionDepthwiseSWCPUKernel::FreePackedInputOutput() {
void ConvolutionDepthwiseSWCPUKernel::PackWeight() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
- auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
+ auto origin_weight = reinterpret_cast<float *>(weight_tensor->data_c());
+ MS_ASSERT(origin_weight != nullptr);
PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, weight_tensor->Height() * weight_tensor->Width(),
weight_tensor->Batch());
}
int ConvolutionDepthwiseSWCPUKernel::Eval() {
- InnerKernel::Eval();
+ auto ret = InnerKernel::Eval();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "eval failed!";
+   return ret;
+ }
if (is_trainable()) {
PackWeight();
}

View File

@ -41,7 +41,8 @@ ConvolutionDepthwiseSWCPUKernelX86::~ConvolutionDepthwiseSWCPUKernelX86() {
int ConvolutionDepthwiseSWCPUKernelX86::InitWeightBias() {
// init weight: o, h, w, i; o == group, i == 1
auto weight_tensor = in_tensors_.at(kWeightIndex);
- origin_weight_ = reinterpret_cast<float *>(weight_tensor->MutableData());
+ origin_weight_ = reinterpret_cast<float *>(weight_tensor->data_c());
+ MS_ASSERT(origin_weight_ != nullptr);
int oc_algin = UP_DIV(weight_tensor->Batch(), oc_tile_);
int pack_weight_size = oc_algin * oc_tile_ * weight_tensor->Height() * weight_tensor->Width();
@ -55,7 +56,7 @@ int ConvolutionDepthwiseSWCPUKernelX86::InitWeightBias() {
if (in_tensors_.size() == kInputSize2) {
auto bias_size = oc_algin * oc_tile_;
auto bias_tensor = in_tensors_.at(kBiasIndex);
- auto ori_bias = reinterpret_cast<float *>(bias_tensor->MutableData());
+ auto ori_bias = reinterpret_cast<float *>(bias_tensor->data_c());
packed_bias_ = reinterpret_cast<float *>(malloc(bias_size * sizeof(float)));
if (packed_bias_ == nullptr) {
MS_LOG(ERROR) << "Malloc bias_data buffer failed.";
@ -149,7 +150,8 @@ int ConvolutionDepthwiseSWCPUKernelX86::Run() {
}
auto input_tensor = in_tensors_.at(kInputIndex);
- auto input_ptr = reinterpret_cast<float *>(input_tensor->MutableData());
+ auto input_ptr = reinterpret_cast<float *>(input_tensor->data_c());
+ MS_ASSERT(input_ptr != nullptr);
if (input_need_align_) {
PackNHWCToNHWCXFp32(input_ptr, packed_input_, conv_param_->input_batch_,
@ -159,7 +161,8 @@ int ConvolutionDepthwiseSWCPUKernelX86::Run() {
}
auto output_tensor = out_tensors_.at(kOutputIndex);
- auto output_ptr = reinterpret_cast<float *>(output_tensor->MutableData());
+ auto output_ptr = reinterpret_cast<float *>(output_tensor->data_c());
+ MS_ASSERT(output_ptr != nullptr);
if (!output_need_align_) {
packed_output_ = output_ptr;
@ -198,7 +201,11 @@ void ConvolutionDepthwiseSWCPUKernelX86::PackWeight() {
}
int ConvolutionDepthwiseSWCPUKernelX86::Eval() {
- InnerKernel::Eval();
+ auto ret = InnerKernel::Eval();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "eval failed!";
+   return ret;
+ }
if (is_trainable()) {
PackWeight();
}

View File

@ -191,7 +191,9 @@ int ConvolutionWinogradCPUKernel::ReSize() {
int ConvolutionWinogradCPUKernel::RunImpl(int task_id) {
auto input_tensor = in_tensors_.at(kInputIndex);
auto ori_input_data = reinterpret_cast<float *>(input_tensor->data_c());
MS_ASSERT(ori_input_data != nullptr);
auto output_data = reinterpret_cast<float *>(out_tensors_.front()->data_c());
MS_ASSERT(output_data != nullptr);
ConvWinogardFp32(ori_input_data, trans_weight_, reinterpret_cast<const float *>(bias_data_), output_data,
tmp_buffer_address_list_, task_id, conv_param_, in_func_, out_func_);
return RET_OK;
@ -215,7 +217,11 @@ int ConvolutionWinogradCPUKernel::Run() {
return RET_ERROR;
}
if (IsTrain() && is_trainable()) {
- InitWeightBias();
+ ret = InitWeightBias();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "Init weight bias failed.";
+   return RET_ERROR;
+ }
}
ret = static_cast<const lite::InnerContext *>(this->context_)
@ -229,9 +235,17 @@ int ConvolutionWinogradCPUKernel::Run() {
}
int ConvolutionWinogradCPUKernel::Eval() {
- InnerKernel::Eval();
+ auto ret = InnerKernel::Eval();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "eval failed!";
+   return ret;
+ }
if (is_trainable()) {
- InitWeightBias();
+ ret = InitWeightBias();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "Init weight bias failed.";
+   return RET_ERROR;
+ }
}
return RET_OK;
}

View File

@ -48,7 +48,8 @@ int DeconvolutionDepthwiseCPUKernel::InitSlideParam() {
int DeconvolutionDepthwiseCPUKernel::InitWeightBias() {
// init weight: o, h, w, i; o == group, i == 1
auto weight_tensor = in_tensors_.at(kWeightIndex);
- auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
+ auto origin_weight = reinterpret_cast<float *>(weight_tensor->data_c());
+ MS_ASSERT(origin_weight != nullptr);
int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
int pack_weight_size = C4NUM * OC4 * weight_tensor->Height() * weight_tensor->Width();
@ -67,7 +68,7 @@ int DeconvolutionDepthwiseCPUKernel::InitWeightBias() {
}
memset(bias_data_, 0, C4NUM * OC4 * sizeof(float));
if (in_tensors_.size() == kInputSize2) {
- auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->MutableData());
+ auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->data_c());
memcpy(bias_data_, ori_bias, in_tensors_.at(kBiasIndex)->ElementsNum() * sizeof(float));
}
@ -117,8 +118,16 @@ int DeconvolutionDepthwiseCPUKernel::Init() {
}
int DeconvolutionDepthwiseCPUKernel::ReSize() {
- InitSlideParam();
- ConvolutionBaseCPUKernel::Init();
+ auto ret = InitSlideParam();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "InitSlideParam is failed!";
+   return ret;
+ }
+ ret = ConvolutionBaseCPUKernel::Init();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "ConvolutionBaseCPUKernel init failed!";
+   return ret;
+ }
return RET_OK;
}
@ -152,8 +161,8 @@ int DeconvolutionDepthwiseCPUKernel::Run() {
}
auto input_tensor = in_tensors_.at(kInputIndex);
- auto input_addr = reinterpret_cast<float *>(input_tensor->MutableData());
+ auto input_addr = reinterpret_cast<float *>(input_tensor->data_c());
+ MS_ASSERT(input_addr != nullptr);
if (need_align_) {
PackNHWCToNHWC4Fp32(input_addr, packed_input_, conv_param_->input_batch_,
conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
@ -161,7 +170,8 @@ int DeconvolutionDepthwiseCPUKernel::Run() {
packed_input_ = input_addr;
}
- auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
+ auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data_c());
+ MS_ASSERT(output_addr != nullptr);
if (!need_align_) {
memset(output_addr, 0, out_tensors_.at(kOutputIndex)->ElementsNum() * sizeof(float));
packed_output_ = output_addr;

View File

@ -38,7 +38,11 @@ DeConvolutionCPUKernel::~DeConvolutionCPUKernel() {
}
int DeConvolutionCPUKernel::ReSize() {
- ConvolutionBaseCPUKernel::Init();
+ auto ret = ConvolutionBaseCPUKernel::Init();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "ConvolutionBaseCPUKernel init error!";
+   return ret;
+ }
int error_code = InitParam();
if (error_code != RET_OK) {
@ -64,7 +68,8 @@ int DeConvolutionCPUKernel::InitWeightBias() {
if (in_tensors_.size() == 3) {
if (in_tensors_.at(kBiasIndex)->shape().size() == 1 &&
in_tensors_.at(kBiasIndex)->DimensionSize(0) == output_channel) {
- memcpy(bias_data_, in_tensors_.at(2)->MutableData(), output_channel * sizeof(float));
+ MS_ASSERT(in_tensors_.at(kBiasIndex)->data_c() != nullptr);
+ memcpy(bias_data_, in_tensors_.at(kBiasIndex)->data_c(), output_channel * sizeof(float));
} else {
MS_LOG(ERROR) << "unsupported bias shape for deconv!";
return RET_ERROR;
@ -78,7 +83,8 @@ int DeConvolutionCPUKernel::InitWeightBias() {
return RET_ERROR;
}
memset(weight_ptr_, 0, weight_pack_size);
- PackNHWCToC8HWN8Fp32(reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()), weight_ptr_, input_channel,
+ MS_ASSERT(in_tensors_.at(kWeightIndex)->data_c() != nullptr);
+ PackNHWCToC8HWN8Fp32(reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->data_c()), weight_ptr_, input_channel,
kernel_w_ * kernel_h_, output_channel);
return RET_OK;
}
@ -206,9 +212,10 @@ int DeConvolutionCPUKernel::InitRunBuf() {
}
int DeConvolutionCPUKernel::Run() {
- float *src_in = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
- float *src_out = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
+ float *src_in = reinterpret_cast<float *>(in_tensors_[0]->data_c());
+ float *src_out = reinterpret_cast<float *>(out_tensors_[0]->data_c());
+ MS_ASSERT(src_in != nullptr);
+ MS_ASSERT(src_out != nullptr);
int error_code = InitRunBuf();
if (error_code != RET_OK) {
MS_LOG(ERROR) << "deconv fp32 InitRunBuf error! error_code[" << error_code << "]";

View File

@ -265,7 +265,8 @@ int DeConvolutionWinogradCPUKernel::InitComputeParam() {
int DeConvolutionWinogradCPUKernel::InitDataParam() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
- float *nhwc_weight = reinterpret_cast<float *>(weight_tensor->data_c());
+ auto nhwc_weight = reinterpret_cast<float *>(weight_tensor->data_c());
+ MS_ASSERT(nhwc_weight != nullptr);
/* unit data : weight & winograd data */
for (int i = 0; i < deconv_param_->compute_size_; i++) {
@ -286,6 +287,7 @@ int DeConvolutionWinogradCPUKernel::InitDataParam() {
if (in_tensors_.size() == 3 && in_tensors_.at(kBiasIndex)->shape().size() == 1 &&
in_tensors_.at(kBiasIndex)->DimensionSize(0) == conv_param_->output_channel_) {
auto bias_tensor = in_tensors_.at(kBiasIndex);
MS_ASSERT(bias_tensor->data_c() != nullptr);
memcpy(bias_data_, bias_tensor->data_c(), conv_param_->output_channel_ * sizeof(float));
}
return RET_OK;
@ -402,18 +404,29 @@ int DeConvolutionWinogradCPUKernel::Run() {
float *src_in = reinterpret_cast<float *>(in_tensors_[0]->data_c());
float *src_out = reinterpret_cast<float *>(out_tensors_[0]->data_c());
MS_ASSERT(src_in != nullptr);
MS_ASSERT(src_out != nullptr);
for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
nhwc_input_ = src_in + batch_index * deconv_param_->input_plane_ * conv_param_->input_channel_;
nhwc_output_ = src_out + batch_index * deconv_param_->output_plane_ * conv_param_->output_channel_;
::memset(nc4hw4_output_, 0, deconv_param_->output_plane_ * deconv_param_->oc_div4_ * C4NUM * sizeof(float));
- static_cast<const lite::InnerContext *>(this->context_)
-   ->thread_pool_->ParallelLaunch(DeConvWgFp32Run, this, deconv_param_->thread_num_);
+ ret = static_cast<const lite::InnerContext *>(this->context_)
+         ->thread_pool_->ParallelLaunch(DeConvWgFp32Run, this, deconv_param_->thread_num_);
+ if (ret != RET_OK) {
+   FreeRunBuf();
+   MS_LOG(ERROR) << "DeConvWgFp32Run failed!";
+   return ret;
+ }
/* post bias activate and nhwc */
- static_cast<const lite::InnerContext *>(this->context_)
-   ->thread_pool_->ParallelLaunch(DeConvWgPostFp32Run, this, thread_num_hw_);
+ ret = static_cast<const lite::InnerContext *>(this->context_)
+         ->thread_pool_->ParallelLaunch(DeConvWgPostFp32Run, this, thread_num_hw_);
+ if (ret != RET_OK) {
+   FreeRunBuf();
+   MS_LOG(ERROR) << "DeConvWgPostFp32Run failed!";
+   return ret;
+ }
}
FreeRunBuf();

View File

@ -77,8 +77,11 @@ int MatmulCPUKernel::ReSize() {
}
int MatmulCPUKernel::Run() {
- MatmulFp32BaseCPUKernel::Run();
- return RET_OK;
+ auto ret = MatmulFp32BaseCPUKernel::Run();
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "MatmulFp32BaseCPUKernel failed!";
+ }
+ return ret;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MatMul, LiteKernelCreator<MatmulCPUKernel>)

View File

@ -390,7 +390,11 @@ int MatmulFp32BaseCPUKernel::Run() {
if (RET_OK != InitBufferA()) {
return RET_ERROR;
}
- InitMatrixA(a_ptr);
+ auto ret = InitMatrixA(a_ptr);
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "InitMatrixA failed!";
+   return ret;
+ }
}
if (!params_->b_const_) {
auto b_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->data_c());
@ -398,7 +402,11 @@ int MatmulFp32BaseCPUKernel::Run() {
FreeResizeBufA();
return RET_ERROR;
}
- InitMatrixB(b_ptr);
+ auto ret = InitMatrixB(b_ptr);
+ if (ret != RET_OK) {
+   MS_LOG(ERROR) << "InitMatrixB failed!";
+   return ret;
+ }
}
auto ret = InitTmpOutBuffer();