!3799 conv1x1 & deconv change

Merge pull request !3799 from ling/conv1x1
This commit is contained in:
mindspore-ci-bot 2020-08-03 16:01:44 +08:00 committed by Gitee
commit 8beb1b0fb5
31 changed files with 2262 additions and 301 deletions

View File

@ -165,8 +165,7 @@ OpParameter *PopulateFullconnectionParameter(const lite::Primitive *primitive) {
matmul_param->b_transpose_ = true;
matmul_param->a_transpose_ = false;
matmul_param->has_bias_ = param->hasBias();
matmul_param->minf_ = -FLT_MAX;
matmul_param->maxf_ = FLT_MAX;
matmul_param->act_type_ = ActType_No;
return reinterpret_cast<OpParameter *>(matmul_param);
}
@ -181,8 +180,7 @@ OpParameter *PopulateMatMulParameter(const lite::Primitive *primitive) {
matmul_param->b_transpose_ = param->transposeB();
matmul_param->a_transpose_ = param->transposeA();
matmul_param->has_bias_ = false;
matmul_param->minf_ = -FLT_MAX;
matmul_param->maxf_ = FLT_MAX;
matmul_param->act_type_ = ActType_No;
return reinterpret_cast<OpParameter *>(matmul_param);
}

View File

@ -146,28 +146,10 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[0], &conv_quant_arg_->left_shift_[0],
&conv_quant_arg_->right_shift_[0]);
ComputeQuantOutRange(conv_param_);
CalculateActivationRangeQuantized(
conv_param_->is_relu_, conv_param_->is_relu6_, conv_param_->conv_quant_arg_.quant_args_[2][0].zp_,
conv_param_->conv_quant_arg_.quant_args_[2][0].scale_, &conv_param_->conv_quant_arg_.out_act_min_[0],
&conv_param_->conv_quant_arg_.out_act_max_[0]);
return RET_OK;
}
void ComputeQuantOutRange(ConvParameter *conv_param) {
int32_t min = std::numeric_limits<int8_t>::min();
int32_t max = std::numeric_limits<int8_t>::max();
float scale = conv_param->conv_quant_arg_.quant_args_[2][0].scale_;
int32_t zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_;
bool is_relu = conv_param->is_relu_;
bool is_relu6 = conv_param->is_relu6_;
int32_t quantized_zero = QuantizeToInt8(0, scale, zp);
int32_t quantized_six = QuantizeToInt8(6, scale, zp);
if (is_relu) {
min = min > quantized_zero ? min : quantized_zero;
} else if (is_relu6) {
min = min > quantized_zero ? min : quantized_zero;
max = max < quantized_six ? max : quantized_six;
} else {
// do nothing
}
conv_param->conv_quant_arg_.out_act_min_[0] = min;
conv_param->conv_quant_arg_.out_act_max_[0] = max;
}
} // namespace mindspore::kernel

View File

@ -38,7 +38,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
public:
ConvolutionBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
: LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) {
: LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) {
opParameter->thread_num_ = ctx->threadNum;
conv_param_ = reinterpret_cast<ConvParameter *>(opParameter);
}
@ -60,7 +60,6 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
ConvParameter *conv_param_;
LayoutConvertor convert_func_;
};
void ComputeQuantOutRange(ConvParameter *conv_param);
bool CheckSupportFP16();
} // namespace mindspore::kernel

View File

@ -23,62 +23,71 @@ using mindspore::lite::RET_OK;
namespace mindspore::kernel {
Convolution1x1CPUKernel::~Convolution1x1CPUKernel() {
if (c4_output_ != nullptr) {
free(c4_output_);
c4_output_ = nullptr;
}
if (c4_input_ != nullptr) {
free(c4_input_);
c4_input_ = nullptr;
}
if (pre_trans_input_) {
free(input_ptr_);
input_ptr_ = nullptr;
}
if (tmp_ptr_ != nullptr) {
free(tmp_ptr_);
tmp_ptr_ = nullptr;
}
if (weight_ptr_ != nullptr) {
free(weight_ptr_);
weight_ptr_ = nullptr;
}
if (pack_input_ != nullptr) {
free(pack_input_);
pack_input_ = nullptr;
}
if (pack_output_ != nullptr) {
free(pack_output_);
pack_output_ = nullptr;
}
if (pre_trans_input_ && input_ptr_ != nullptr) {
free(input_ptr_);
input_ptr_ = nullptr;
}
delete matmul_param_;
}
int Convolution1x1CPUKernel::ReSize() { return RET_OK; }
int Convolution1x1CPUKernel::ReSize() {
if (pack_input_ != nullptr) {
free(pack_input_);
pack_input_ = nullptr;
}
if (pre_trans_input_ && input_ptr_ != nullptr) {
free(input_ptr_);
input_ptr_ = nullptr;
}
InitConv1x1MatmulParam();
InitConv1x1Param();
return RET_OK;
}
void Convolution1x1CPUKernel::InitConv1x1MatmulParam() {
matmul_param_ = new StrassenMatMulParameter();
matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_;
matmul_param_->col_ = UP_DIV(conv_param_->output_channel_, FP32_STRASSEN_UINT);
matmul_param_->deep_ = UP_DIV(conv_param_->input_channel_, FP32_STRASSEN_UINT);
matmul_param_->a_stride_ = matmul_param_->row_ * FP32_STRASSEN_UINT;
matmul_param_->b_stride_ = matmul_param_->deep_ * FP32_STRASSEN_WEIGHT_UINT;
matmul_param_->c_stride_ = matmul_param_->row_ * FP32_STRASSEN_UINT;
matmul_param_->col_ = conv_param_->output_channel_;
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->row_8_ = UP_ROUND(matmul_param_->row_, C8NUM);
matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
matmul_param_->act_type_ = (conv_param_->is_relu6_) ? ActType_Relu6 : ActType_No;
matmul_param_->act_type_ = (conv_param_->is_relu_) ? ActType_Relu : matmul_param_->act_type_;
return;
}
int Convolution1x1CPUKernel::InitConv1x1BiasWeight() {
if (inputs_.size() == 3) {
bias_data_ = malloc(matmul_param_->col_ * C4NUM * sizeof(float));
bias_data_ = malloc(matmul_param_->col_8_ * sizeof(float));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!";
return RET_ERROR;
}
memset(bias_data_, 0, matmul_param_->col_ * C4NUM * sizeof(float));
memset(bias_data_, 0, matmul_param_->col_8_ * sizeof(float));
memcpy(bias_data_, inputs_[2]->Data(), conv_param_->output_channel_ * sizeof(float));
} else {
bias_data_ = nullptr;
}
weight_ptr_ = reinterpret_cast<float *>(
malloc(matmul_param_->col_ * matmul_param_->deep_ * FP32_STRASSEN_WEIGHT_UINT * sizeof(float)));
weight_ptr_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float)));
if (weight_ptr_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!";
return RET_ERROR;
}
memset(weight_ptr_, 0, matmul_param_->col_ * matmul_param_->deep_ * FP32_STRASSEN_WEIGHT_UINT * sizeof(float));
Pack1x1WeightFp32(reinterpret_cast<float *>(inputs_[1]->Data()), weight_ptr_, conv_param_);
memset(weight_ptr_, 0, matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float));
RowMajor2Col8Major(reinterpret_cast<float *>(inputs_[1]->Data()), weight_ptr_, matmul_param_->col_,
matmul_param_->deep_);
return RET_OK;
}
@ -86,52 +95,43 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
pre_trans_input_ = (conv_param_->pad_h_ != 0 || conv_param_->pad_w_ != 0 || conv_param_->stride_h_ != 1 ||
conv_param_->stride_w_ != 1);
if (pre_trans_input_) {
input_ptr_ = reinterpret_cast<float *>(malloc(matmul_param_->a_stride_ * matmul_param_->deep_ * sizeof(float)));
input_ptr_ = reinterpret_cast<float *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(float)));
if (input_ptr_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc input_ptr_ error!";
return RET_MEMORY_FAILED;
}
memset(input_ptr_, 0, matmul_param_->a_stride_ * matmul_param_->deep_ * sizeof(float));
memset(input_ptr_, 0, matmul_param_->row_ * matmul_param_->deep_ * sizeof(float));
}
thread_hw_count_ = MSMIN(opParameter->thread_num_, matmul_param_->row_);
thread_hw_stride_ = UP_DIV(matmul_param_->row_, thread_hw_count_);
thread_count_ = MSMIN(opParameter->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM;
thread_oc4_count_ = MSMIN(opParameter->thread_num_, matmul_param_->col_);
thread_oc_stride_ = UP_DIV(matmul_param_->col_, thread_oc4_count_) * C4NUM;
tmp_ptr_ = reinterpret_cast<float *>(malloc(matmul_param_->a_stride_ * matmul_param_->deep_ * sizeof(float)));
if (tmp_ptr_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc tmp_ptr_ error!";
return RET_MEMORY_FAILED;
}
c4_output_ =
reinterpret_cast<float *>(malloc(outputs_[0]->ElementsC4Num() / conv_param_->output_batch_ * sizeof(float)));
if (c4_output_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc c4_output_ error!";
pack_input_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float)));
if (pack_input_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!";
return RET_MEMORY_FAILED;
}
memset(pack_input_, 0, matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float));
c4_input_ =
reinterpret_cast<float *>(malloc(inputs_[0]->ElementsC4Num() / conv_param_->input_batch_ * sizeof(float)));
if (c4_input_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc c4_input_ error!";
pack_output_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float)));
if (pack_output_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc pack_output_ error!";
return RET_MEMORY_FAILED;
}
memset(pack_output_, 0, matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float));
return RET_OK;
}
void Convolution1x1CPUKernel::Pre1x1Trans(float *src_input, float *src_output) {
output_ptr_ = src_output;
PackNHWCToNC4HW4Fp32(src_input, c4_input_, 1, conv_param_->input_h_ * conv_param_->input_w_,
conv_param_->input_channel_);
if (!pre_trans_input_) {
input_ptr_ = c4_input_;
return;
if (pre_trans_input_) {
Conv1x1InputPackFp32(src_input, input_ptr_, conv_param_);
} else {
input_ptr_ = src_input;
}
Conv1x1InputPackFp32(c4_input_, input_ptr_, conv_param_);
RowMajor2Col8Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
return;
}
@ -152,53 +152,26 @@ int Convolution1x1CPUKernel::Init() {
return RET_OK;
}
int Convolution1x1CPUKernel::DoStrassen(int task_id) {
matmul_param_->row_ = MSMIN(thread_hw_stride_, matmul_param_->row_ - task_id * thread_hw_stride_);
if (matmul_param_->row_ <= 0) {
return RET_OK;
}
auto error_code = Conv1x1Fp32(input_ptr_ + task_id * thread_hw_stride_ * C4NUM, weight_ptr_,
c4_output_ + task_id * thread_hw_stride_ * C4NUM,
tmp_ptr_ + task_id * thread_hw_stride_ * matmul_param_->deep_ * C4NUM, *matmul_param_);
if (error_code != 0) {
MS_LOG(ERROR) << "DoStrassen error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_;
return RET_OK;
}
int Convolution1x1StrassenRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata);
auto error_code = conv1x1->DoStrassen(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Convolution1x1StrassenRun error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int Convolution1x1CPUKernel::DoPostFunc(int task_id) {
int cur_oc = MSMIN(thread_oc_stride_, conv_param_->output_channel_ - task_id * thread_oc_stride_);
int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
int cur_oc = MSMIN(thread_stride_, matmul_param_->col_8_ - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}
float *cur_bias =
(bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + task_id * thread_oc_stride_;
auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id;
MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_, bias, matmul_param_->act_type_,
matmul_param_->deep_, matmul_param_->row_8_, cur_oc);
PostConvFuncFp32(c4_output_ + matmul_param_->row_ * thread_oc_stride_ * task_id,
output_ptr_ + task_id * thread_oc_stride_, cur_bias, cur_oc, matmul_param_->row_,
conv_param_->output_channel_, conv_param_->is_relu_, conv_param_->is_relu6_);
return RET_OK;
}
int Convolution1x1PostFuncRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata);
auto error_code = conv1x1->DoPostFunc(task_id);
auto error_code = conv1x1->DoConv1x1(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Convolution1x1PostFuncRun error task_id[" << task_id << "] error_code[" << error_code << "]";
MS_LOG(ERROR) << "Convolution1x1Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
@ -209,20 +182,16 @@ int Convolution1x1CPUKernel::Run() {
auto src_out = reinterpret_cast<float *>(outputs_[0]->Data());
for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
Pre1x1Trans(src_in + batch_index * matmul_param_->deep_ * matmul_param_->a_stride_,
src_out + batch_index * matmul_param_->col_ * matmul_param_->c_stride_);
Pre1x1Trans(src_in + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_,
src_out + batch_index * matmul_param_->row_ * matmul_param_->col_);
int error_code = LiteBackendParallelLaunch(Convolution1x1StrassenRun, this, thread_hw_count_);
int error_code = LiteBackendParallelLaunch(Convolution1x1Run, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv1x1 strassen error error_code[" << error_code << "]";
return RET_ERROR;
}
error_code = LiteBackendParallelLaunch(Convolution1x1PostFuncRun, this, thread_oc4_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv1x1 post function error error_code[" << error_code << "]";
return RET_ERROR;
}
Row8x8Major2RowMajor(pack_output_, output_ptr_, matmul_param_->row_, matmul_param_->col_);
}
return RET_OK;
}

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_
#include <float.h>
#include <vector>
#include "src/lite_kernel.h"
#include "include/errorcode.h"
@ -26,21 +27,24 @@
#include "src/runtime/kernel/arm/base/layout_transform.h"
#include "src/runtime/kernel/arm/opclib/fp32/conv.h"
#include "src/runtime/kernel/arm/opclib/fp32/common_func.h"
#include "src/runtime/kernel/arm/opclib/matmul.h"
#include "src/runtime/kernel/arm/opclib/fp32/matmul.h"
namespace mindspore::kernel {
class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
public:
Convolution1x1CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {
matmul_param_ = new MatMulParameter();
}
~Convolution1x1CPUKernel();
int Init() override;
int Run() override;
int ReSize() override;
public:
int DoStrassen(int task_id);
int DoPostFunc(int task_id);
int DoConv1x1(int task_id);
private:
int InitConv1x1Param();
@ -49,20 +53,15 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
void Pre1x1Trans(float *src_input, float *src_output);
private:
StrassenMatMulParameter *matmul_param_ = nullptr;
MatMulParameter *matmul_param_ = nullptr;
bool pre_trans_input_ = false;
int thread_count_ = 0;
int thread_hw_count_ = 0;
int thread_hw_stride_ = 0;
int thread_oc4_count_ = 0;
int thread_oc_stride_ = 0;
int thread_stride_ = 0;
float *weight_ptr_ = nullptr;
float *tmp_ptr_ = nullptr;
float *c4_input_ = nullptr;
float *c4_output_ = nullptr;
float *pack_input_ = nullptr;
float *pack_output_ = nullptr;
float *input_ptr_ = nullptr;
float *output_ptr_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_

View File

@ -30,27 +30,38 @@ DeConvolutionCPUKernel::~DeConvolutionCPUKernel() {
free(weight_ptr_);
weight_ptr_ = nullptr;
}
if (tmp_output_ != nullptr) {
free(tmp_output_);
tmp_output_ = nullptr;
}
if (tmp_buffer_ != nullptr) {
free(tmp_buffer_);
tmp_buffer_ = nullptr;
}
if (c4_input_ != nullptr) {
free(c4_input_);
c4_input_ = nullptr;
if (pack_input_ != nullptr) {
free(pack_input_);
pack_input_ = nullptr;
}
if (c4_output_ != nullptr) {
free(c4_output_);
c4_output_ = nullptr;
if (pack_output_ != nullptr) {
free(pack_output_);
pack_output_ = nullptr;
}
return;
}
int DeConvolutionCPUKernel::ReSize() { return 0; }
int DeConvolutionCPUKernel::ReSize() {
if (tmp_buffer_ != nullptr) {
free(tmp_buffer_);
tmp_buffer_ = nullptr;
}
if (pack_input_ != nullptr) {
free(pack_input_);
pack_input_ = nullptr;
}
if (pack_output_ != nullptr) {
free(pack_output_);
pack_output_ = nullptr;
}
InitParam();
return RET_OK;
}
int DeConvolutionCPUKernel::InitWeightBias() {
if (inputs_.size() == 3) {
@ -65,61 +76,51 @@ int DeConvolutionCPUKernel::InitWeightBias() {
bias_data_ = nullptr;
}
size_t weight_pack_size = conv_param_->kernel_w_ * conv_param_->kernel_h_ *
UP_ROUND(conv_param_->output_channel_, C4NUM) *
UP_ROUND(conv_param_->input_channel_, C4NUM) * sizeof(float);
size_t weight_pack_size = conv_param_->input_channel_ * conv_param_->kernel_w_ * conv_param_->kernel_h_ *
UP_ROUND(conv_param_->output_channel_, C8NUM) * sizeof(float);
weight_ptr_ = reinterpret_cast<float *>(malloc(weight_pack_size));
if (weight_ptr_ == nullptr) {
MS_LOG(ERROR) << "deconv malloc weight_ptr_ error!";
return RET_ERROR;
}
memset(weight_ptr_, 0, weight_pack_size);
PackDeConvWeightFp32(reinterpret_cast<float *>(inputs_[1]->Data()), weight_ptr_, conv_param_->input_channel_,
conv_param_->output_channel_, conv_param_->kernel_w_ * conv_param_->kernel_h_);
PackNHWCToC8HWN8Fp32(reinterpret_cast<float *>(inputs_[1]->Data()), weight_ptr_, conv_param_->input_channel_,
kernel_plane_, conv_param_->output_channel_);
return RET_OK;
}
int DeConvolutionCPUKernel::InitParam() {
matmul_param_ = new StrassenMatMulParameter();
matmul_param_->row_ = conv_param_->input_h_ * conv_param_->input_w_;
matmul_param_->deep_ = UP_DIV(conv_param_->input_channel_, C4NUM);
matmul_param_->col_ = UP_DIV(conv_param_->output_channel_, 4) * conv_param_->kernel_w_ * conv_param_->kernel_h_;
matmul_param_->a_stride_ = matmul_param_->row_ * C4NUM;
matmul_param_->b_stride_ = matmul_param_->deep_ * C4NUM * C4NUM;
matmul_param_->c_stride_ = matmul_param_->row_ * C4NUM;
input_plane_ = conv_param_->input_h_ * conv_param_->input_w_;
kernel_plane_ = conv_param_->kernel_w_ * conv_param_->kernel_h_;
output_plane_ = conv_param_->output_h_ * conv_param_->output_w_;
thread_hw_count_ = MSMIN(opParameter->thread_num_, matmul_param_->row_);
thread_hw_stride_ = UP_DIV(matmul_param_->row_, thread_hw_count_);
matmul_param_->row_ = input_plane_;
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_;
matmul_param_->row_8_ = UP_ROUND(matmul_param_->row_, C8NUM);
matmul_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * kernel_plane_;
thread_co4_count_ = MSMIN(opParameter->thread_num_, UP_DIV(conv_param_->output_channel_, C4NUM));
thread_co_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C4NUM), thread_co4_count_) * C4NUM;
thread_count_ = MSMIN(opParameter->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM));
thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM), thread_count_);
tmp_buffer_ =
reinterpret_cast<float *>(malloc(matmul_param_->a_stride_ * matmul_param_->deep_ * C4NUM * sizeof(float)));
pack_input_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float)));
if (pack_input_ == nullptr) {
MS_LOG(ERROR) << "deconv Malloc pack_input_ error!";
return RET_ERROR;
}
pack_output_ =
reinterpret_cast<float *>(malloc(UP_ROUND(conv_param_->output_channel_, C8NUM) * output_plane_ * sizeof(float)));
if (pack_output_ == nullptr) {
MS_LOG(ERROR) << "deconv Malloc pack_output_ error!";
return RET_NULL_PTR;
}
tmp_buffer_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float)));
if (tmp_buffer_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc tmp_buffer_ error!";
return RET_ERROR;
}
tmp_output_ = reinterpret_cast<float *>(malloc(matmul_param_->row_ * matmul_param_->col_ * C4NUM * sizeof(float)));
if (tmp_output_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc tmp_output_ error!";
return RET_ERROR;
}
c4_input_ =
reinterpret_cast<float *>(malloc(inputs_[0]->ElementsC4Num() / conv_param_->input_batch_ * sizeof(float)));
if (c4_input_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc c4_input_ error!";
return RET_NULL_PTR;
}
c4_output_ =
reinterpret_cast<float *>(malloc(outputs_[0]->ElementsC4Num() / conv_param_->output_batch_ * sizeof(float)));
if (c4_output_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc c4_output_ error!";
return RET_NULL_PTR;
}
return RET_OK;
}
@ -132,6 +133,7 @@ int DeConvFp32Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
}
return RET_OK;
}
int DeConvFp32PostRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto deconv = reinterpret_cast<DeConvolutionCPUKernel *>(cdata);
auto error_code = deconv->DoPostFunc(task_id);
@ -141,51 +143,39 @@ int DeConvFp32PostRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
}
return RET_OK;
}
int DeConvolutionCPUKernel::DoDeconv(int task_id) {
matmul_param_->row_ = MSMIN(thread_hw_stride_, matmul_param_->row_ - task_id * thread_hw_stride_);
if (matmul_param_->row_ <= 0) {
int oc = MSMIN(thread_stride_, UP_DIV(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_);
if (oc <= 0) {
return RET_OK;
}
int error_code = DeConvFp32(c4_input_ + task_id * thread_hw_stride_ * C4NUM, weight_ptr_,
tmp_output_ + task_id * thread_hw_stride_ * C4NUM,
tmp_buffer_ + task_id * thread_hw_stride_ * matmul_param_->deep_ * C4NUM, *matmul_param_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "DeConvFp32 error! error code: " << error_code;
return error_code;
}
MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_, nullptr, ActType_No,
matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_);
matmul_param_->row_ = conv_param_->input_h_ * conv_param_->input_w_;
return RET_OK;
}
int DeConvolutionCPUKernel::DoPostFunc(int task_id) {
int input_plane = conv_param_->input_h_ * conv_param_->input_w_;
int kernel_plane = conv_param_->kernel_w_ * conv_param_->kernel_h_;
int output_plane = conv_param_->output_h_ * conv_param_->output_w_;
int cur_oc = MSMIN(thread_co_stride_, conv_param_->output_channel_ - task_id * thread_co_stride_);
if (cur_oc <= 0) {
int oc = MSMIN(thread_stride_ * C8NUM, conv_param_->output_channel_ - task_id * thread_stride_ * C8NUM);
if (oc <= 0) {
return RET_OK;
}
float *cur_bias =
(bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_co_stride_ * task_id;
float *bias =
(bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM;
DeConvPostFp32(tmp_output_ + thread_co_stride_ * task_id * input_plane * kernel_plane,
c4_output_ + thread_co_stride_ * task_id * output_plane, output_ptr_ + thread_co_stride_ * task_id,
cur_bias, cur_oc, input_plane, kernel_plane, output_plane, conv_param_);
DeConvPostFp32C8x8(tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_,
pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_, bias,
output_ptr_ + task_id * thread_stride_ * C8NUM, oc, conv_param_);
return RET_OK;
}
int DeConvolutionCPUKernel::Init() {
int error_code = ConvolutionBaseCPUKernel::Init();
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Conv base init error!";
return error_code;
}
ConvolutionBaseCPUKernel::Init();
error_code = InitParam();
int error_code = InitParam();
if (error_code != RET_OK) {
MS_LOG(ERROR) << "deconv InitParam error!";
return error_code;
@ -204,20 +194,18 @@ int DeConvolutionCPUKernel::Run() {
float *src_out = reinterpret_cast<float *>(outputs_[0]->Data());
for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
input_ptr_ = src_in + batch_index * conv_param_->input_w_ * conv_param_->input_h_ * conv_param_->input_channel_;
output_ptr_ =
src_out + batch_index * conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_channel_;
input_ptr_ = src_in + batch_index * input_plane_ * conv_param_->input_channel_;
output_ptr_ = src_out + batch_index * output_plane_ * conv_param_->output_channel_;
PackNHWCToNC4HW4Fp32(input_ptr_, c4_input_, 1, conv_param_->input_h_ * conv_param_->input_w_,
conv_param_->input_channel_);
RowMajor2Col8Major(input_ptr_, pack_input_, input_plane_, conv_param_->input_channel_);
int error_code = LiteBackendParallelLaunch(DeConvFp32Run, this, thread_hw_count_);
int error_code = LiteBackendParallelLaunch(DeConvFp32Run, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "deconv fp32 run error! error_code[" << error_code << "]";
return RET_ERROR;
}
error_code = LiteBackendParallelLaunch(DeConvFp32PostRun, this, thread_co4_count_);
error_code = LiteBackendParallelLaunch(DeConvFp32PostRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "deconv fp32 postrun error! error_code[" << error_code << "]";
return RET_ERROR;

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_H_
#include <float.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/kernel_registry.h"
@ -24,13 +25,16 @@
#include "schema/model_generated.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "src/runtime/kernel/arm/opclib/fp32/deconv.h"
#include "src/runtime/kernel/arm/opclib/fp32/matmul.h"
namespace mindspore::kernel {
class DeConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
public:
DeConvolutionCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {
matmul_param_ = new MatMulParameter();
}
~DeConvolutionCPUKernel() override;
int Init() override;
int Run() override;
@ -45,19 +49,18 @@ class DeConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
int InitWeightBias();
private:
StrassenMatMulParameter *matmul_param_;
int thread_hw_count_;
int thread_hw_stride_;
int thread_co4_count_;
int thread_co_stride_;
MatMulParameter *matmul_param_;
int input_plane_;
int kernel_plane_;
int output_plane_;
int thread_count_;
int thread_stride_;
float *weight_ptr_;
float *pack_input_;
float *pack_output_;
float *tmp_buffer_;
float *tmp_output_;
float *c4_input_;
float *c4_output_;
float *input_ptr_;
float *output_ptr_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_H_

View File

@ -99,8 +99,8 @@ int FullconnectionCPUKernel::DoMatmul(int task_id) {
MatMul(a_c8_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_,
c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_,
bias_ptr_ + task_id * thread_stride_ * C8NUM, fc_param_->maxf_, fc_param_->minf_, fc_param_->deep_,
fc_param_->row_8_, cur_oc * 8);
bias_ptr_ + task_id * thread_stride_ * C8NUM, fc_param_->act_type_, fc_param_->deep_, fc_param_->row_8_,
cur_oc * 8);
return RET_OK;
}

View File

@ -82,9 +82,9 @@ int FullconnectionInt8CPUKernel::Init() {
double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_;
QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
&quant_params_.right_shift);
CalculateActivationRangeQuantized(fc_param_->maxf_, fc_param_->minf_, quant_params_.output.scale_,
quant_params_.output.zp_, &quant_params_.out_act_max, &quant_params_.out_act_min);
CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6,
quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_max,
&quant_params_.out_act_min);
return RET_OK;
}

View File

@ -63,23 +63,29 @@ void MatrixMultiAdd(float *c11, float *c12, float *c21, float *c22, float *x_ptr
return;
}
void PostConvFuncFp32(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t stride, bool is_relu, bool is_relu6) {
#ifndef ENABLE_ARM64
void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t stride, bool is_relu, bool is_relu6, int size) {
for (int oc = 0; oc < output_channel; oc++) {
int oc4div = oc / 4, oc4mod = oc % 4;
int oc_div = oc / size, oc_mod = oc % size;
for (int hw = 0; hw < plane_size; hw++) {
int src_index = oc4div * 4 * plane_size + hw * 4 + oc4mod;
int src_index = oc_div * size * plane_size + hw * size + oc_mod;
int dst_index = hw * stride + oc;
float value = c4_out_ptr[src_index];
float value = src_ptr_[src_index];
if (bias_ptr != nullptr) {
value = value + bias_ptr[oc];
}
value = (is_relu) ? (MSMAX(0, value)) : (value);
value = (is_relu6) ? (MSMIN(6, MSMAX(0, value))) : (value);
value = (is_relu || is_relu6) ? (MSMAX(0.f, value)) : (value);
value = (is_relu6) ? (MSMIN(6.f, value)) : (value);
out_ptr[dst_index] = value;
}
}
return;
}
void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t stride, bool is_relu, bool is_relu6) {
#ifndef ENABLE_ARM64
PostConvFuncComm(c4_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, stride, is_relu, is_relu6, C4NUM);
#else
if (bias_ptr != nullptr) {
if (is_relu) {
@ -102,3 +108,8 @@ void PostConvFuncFp32(const float *c4_out_ptr, float *out_ptr, const float *bias
return;
}
void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t stride, bool is_relu, bool is_relu6) {
PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, stride, is_relu, is_relu6, C8NUM);
return;
}

View File

@ -27,8 +27,10 @@
extern "C" {
#endif
void PostConvFuncFp32(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t stride, bool is_relu, bool is_relu6);
void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t stride, bool is_relu, bool is_relu6);
void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t stride, bool is_relu, bool is_relu6);
void MatrixAdd(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride,
size_t row, size_t col);
void MatrixSub(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride,
@ -60,4 +62,3 @@ void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_
#endif
#endif /* MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_COMMON_FUNC_H_ */

View File

@ -38,8 +38,52 @@ int DeConvFp32(const float *input, const float *weight, float *output, float *tm
return StrassenMatmul(input, weight, output, &matmul_param, FP32_STRASSEN_MAX_RECURSION, 0, tmp_buffer);
}
int DeConvPostFp32(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel, int input_plane,
int kernel_plane, int output_plane, ConvParameter *conv_param) {
int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param) {
/* row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
int oc8 = UP_DIV(output_channel, C8NUM);
int in_plane8 = UP_ROUND(input_plane, C8NUM);
for (int c = 0; c < oc8; c++) {
float *dst_ptr = tmp + c * output_plane * C8NUM;
const float *src_ptr = src + c * in_plane8 * kernel_plane * C8NUM;
memset(dst_ptr, 0, output_plane * C8NUM * sizeof(int32_t));
for (int ih = 0; ih < conv_param->input_h_; ih++) {
for (int iw = 0; iw < conv_param->input_w_; iw++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
for (int kh = kh_start; kh < kh_end; kh++) {
for (int kw = kw_start; kw < kw_end; kw++) {
int src_index = ih * conv_param->input_w_ * C8NUM + iw * C8NUM +
kh * in_plane8 * conv_param->kernel_w_ * C8NUM + kw * in_plane8 * C8NUM;
int dst_index = oh * conv_param->output_w_ * C8NUM + ow * C8NUM +
kh * conv_param->dilation_h_ * conv_param->output_w_ * C8NUM +
kw * conv_param->dilation_w_ * C8NUM;
for (int i = 0; i < C8NUM; i++) {
dst_ptr[dst_index + i] += src_ptr[src_index + i];
}
} /*kw*/
} /*kh*/
} /*iw*/
} /*ih*/
} /*oc8*/
PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
conv_param->is_relu6_);
return OPCLIB_OK;
}
int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel,
int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param) {
int oc4 = UP_DIV(output_channel, C4NUM);
for (int c = 0; c < oc4; c++) {
float *dst_ptr = tmp_c4 + c * output_plane * C4NUM;
@ -71,8 +115,7 @@ int DeConvPostFp32(const float *src, float *tmp_c4, float *dst, const float *bia
} /*ih*/
} /*oc4*/
PostConvFuncFp32(tmp_c4, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
conv_param->is_relu6_);
PostConvFuncFp32C4(tmp_c4, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
conv_param->is_relu6_);
return OPCLIB_OK;
}

View File

@ -26,8 +26,9 @@ void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, in
int DeConvFp32(const float *input, const float *weight, float *output, float *tmp_buffer,
StrassenMatMulParameter matmul_param);
int DeConvPostFp32(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel, int input_plane,
int kernel_plane, int output_plane, ConvParameter *conv_param);
int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel,
int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param);
int DeConvPostFp32C8x8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_DECONV_H_

View File

@ -48,10 +48,11 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, int row, int col) {
dst_ptr[r * col + c] = src_ptr[cd8 * row8 * 8 + r * 8 + cm8];
}
}
return;
}
void MatMul8x8(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep,
int row_8_, int col_8_) {
void MatMul8x8(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_,
int col_8_) {
/* col8-major * row8-major => col8x8-major */
for (int row = 0; row < row_8_; row++) {
for (int col = 0; col < col_8_; col++) {
@ -64,19 +65,25 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, floa
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
value += bias[col];
value = MSMIN(maxf, value);
value = MSMAX(minf, value);
if (bias != nullptr) {
value += bias[col];
}
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
c[ci] = value;
}
}
return;
}
void MatMul(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep, int row_8_,
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_,
int col_8_) {
#ifdef __aarch64__
float minf = (act_type == ActType_No) ? FLT_MIN : 0.f;
float maxf = (act_type == ActType_Relu6) ? 6.0f : FLT_MAX;
MatMulFloatNeon64(a, b, c, bias, maxf, minf, deep, row_8_, col_8_);
#else
MatMul8x8(a, b, c, bias, maxf, minf, deep, row_8_, col_8_);
MatMul8x8(a, b, c, bias, act_type, deep, row_8_, col_8_);
#endif
return;
}

View File

@ -17,12 +17,12 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_MATMUL_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_MATMUL_H_
#include <float.h>
#include "src/runtime/kernel/arm/opclib/errorcode.h"
#include "src/runtime/kernel/arm/opclib/op_base.h"
#include "src/runtime/kernel/arm/opclib/matmul.h"
void MatMul(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, int row,
int col);
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col);
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, int row, int col);
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, int row, int col);

View File

@ -25,15 +25,18 @@ int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, size_
int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel,
ConvParameter *conv_param) {
int oc8 = UP_DIV(output_channel, C8NUM);
/* row8x8-major(ih*iw x oc*kh*kw) -> row8x8-major(oh*ow x oc) */
size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
int oc8 = UP_DIV(output_channel, C8NUM);
int in_plane8 = UP_ROUND(input_plane, 8);
int out_plane8 = UP_ROUND(output_plane, 8);
for (int c = 0; c < oc8; c++) {
int32_t *dst_ptr = tmp + c * output_plane * C8NUM;
const int32_t *src_ptr = src + c * input_plane * kernel_plane * C8NUM;
memset(dst_ptr, 0, output_plane * C8NUM * sizeof(int32_t));
int32_t *dst_ptr = tmp + c * out_plane8 * C8NUM;
const int32_t *src_ptr = src + c * in_plane8 * kernel_plane * C8NUM;
memset(dst_ptr, 0, out_plane8 * C8NUM * sizeof(int32_t));
for (int ih = 0; ih < conv_param->input_h_; ih++) {
for (int iw = 0; iw < conv_param->input_w_; iw++) {
@ -60,7 +63,7 @@ int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t
} /*ih*/
} /*oc8*/
PostFuncInt8(tmp, bias, out, output_channel, output_plane, UP_ROUND(output_plane, 8),
PostFuncInt8(tmp, bias, out, output_channel, output_plane, out_plane8,
conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);

View File

@ -19,6 +19,8 @@
#include "src/runtime/kernel/arm/opclib/op_base.h"
enum ActType { ActType_No, ActType_Relu, ActType_Relu6 };
struct MatMulParameter {
OpParameter op_parameter_;
int row_;
@ -26,12 +28,10 @@ struct MatMulParameter {
int row_8_;
int col_8_;
int deep_;
float minf_;
float maxf_;
bool has_bias_;
bool a_transpose_; /* false : row-major */
bool b_transpose_; /* true : col-major */
ActType act_type_;
};
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_MATMUL_H_

View File

@ -150,23 +150,21 @@ void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *p
}
void Conv1x1InputPackFp32(const float *src, float *dst, ConvParameter *conv_param) {
for (int c = 0; c < UP_DIV(conv_param->input_channel_, C4NUM); c++) {
const float *src_c_ptr = src + c * conv_param->input_h_ * conv_param->input_w_ * C4NUM;
float *dst_c_ptr = dst + c * conv_param->output_h_ * conv_param->output_w_ * C4NUM;
for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) {
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_h_;
if (src_h < 0 || src_h >= conv_param->input_h_) {
/* support nhwc */
for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) {
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_h_;
if (src_h < 0 || src_h >= conv_param->input_h_) {
continue;
}
const float *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_;
float *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_;
for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) {
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_w_;
if (src_w < 0 || src_w >= conv_param->input_w_) {
continue;
}
const float *src_h_ptr = src_c_ptr + src_h * conv_param->input_w_ * C4NUM;
float *dst_h_ptr = dst_c_ptr + dst_h * conv_param->output_w_ * C4NUM;
for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) {
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_w_;
if (src_w < 0 || src_w >= conv_param->input_w_) {
continue;
}
memcpy(dst_h_ptr + dst_w * C4NUM, src_h_ptr + src_w * C4NUM, C4NUM * sizeof(float));
}
memcpy(dst_h_ptr + dst_w * conv_param->input_channel_, src_h_ptr + src_w * conv_param->input_channel_,
conv_param->input_channel_ * sizeof(float));
}
}
return;
@ -572,6 +570,21 @@ void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int
}
}
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
for (int n = 0; n < batch; n++) {
for (int hw = 0; hw < plane; hw++) {
for (int c = 0; c < channel; c++) {
int c8div = c / C8NUM;
int c8mod = c % C8NUM;
int src_index = n * plane * channel + hw * channel + c;
int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
((float *)dst)[dst_index] = ((float *)src)[src_index];
}
}
}
return;
}
void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
int nhwc4_batch_unit_offset = c4 * C4NUM * plane;

View File

@ -69,6 +69,8 @@ void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int
void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);

View File

@ -21,6 +21,7 @@
#include <math.h>
#include <stdlib.h>
#include <limits.h>
#include <limits>
struct QuantArg {
double scale_;
@ -112,13 +113,21 @@ inline uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { retu
inline int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); }
inline void CalculateActivationRangeQuantized(float fmax, float fmin, float scale, int zero_point, int *imax,
int *imin) {
int8_t qmin = (int8_t)CHAR_MIN;
int8_t qmax = (int8_t)CHAR_MAX;
int8_t qfmin = QuantizeToInt8(fmin, scale, zero_point);
int8_t qfmax = QuantizeToInt8(fmax, scale, zero_point);
*imin = qmin < qfmin ? qmin : qfmin;
*imax = qmax > qfmax ? qmax : qfmax;
inline void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, int32_t scale, int *mini,
int *maxi) {
int32_t min = std::numeric_limits<int8_t>::min();
int32_t max = std::numeric_limits<int8_t>::max();
int32_t quantized_zero = QuantizeToInt8(0, scale, zp);
int32_t quantized_six = QuantizeToInt8(6, scale, zp);
if (is_relu) {
min = min > quantized_zero ? min : quantized_zero;
} else if (is_relu6) {
min = min > quantized_zero ? min : quantized_zero;
max = max < quantized_six ? max : quantized_six;
} else {
// do nothing
}
*mini = min;
*maxi = max;
}
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_QUANTIZE_H_

View File

@ -6,5 +6,15 @@ BUILD_DIR=${CUR_DIR}/../build
mkdir -pv ${CUR_DIR}/do_test
cd ${CUR_DIR}/do_test
cp ${BUILD_DIR}/test/lite-test ./
cp -r ${CUR_DIR}/ut/src/runtime/kernel/arm/test_data/* ./
./lite-test --gtest_filter="*TestHebing*"
./lite-test --gtest_filter=TestFcFp32*
./lite-test --gtest_filter=TestConv1x1Fp32*
./lite-test --gtest_filter=TestStrassenFp32*
./lite-test --gtest_filter=TestDeConvolutionFp32*
./lite-test --gtest_filter=TestPadInt8*

View File

@ -0,0 +1,395 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "src/runtime/kernel/arm/fp32/convolution_1x1.h"
#include "src/runtime/kernel/arm/opclib/matmul.h"
#include "src/runtime/kernel/arm/opclib/strassen_matmul.h"
namespace mindspore {
using mindspore::lite::tensor::Tensor;
class TestConv1x1Fp32 : public mindspore::Common {
public:
TestConv1x1Fp32() {}
};
TEST_F(TestConv1x1Fp32, Input1x1PrePack1) {
auto conv_param = new ConvParameter();
float in[] = {-0.59, -0.63, -7.26, -0.64, -6.403, 4.87, 9.612, 9.36, 12.84, -0.838, 6.588, 2.02, 13.756,
15.92, 16.0, -7.82, 9.53, 1.77, 10.521, 13.45, 17.991, 17.063, 4.6859, 13.57, -6.31, 5.27,
7.54, -7.418, 15.12, 0.6195, 1.5475, -5.925, -7.59, 18.13, 15.8, 19.86, -7.766, 13.25, 7.141,
-0.34, 16.254, -5.78, 16.13, -7.1, 6.259, 10.771, -5.54, 10.477, 9.2366, 12.258, -9.86, -8.29,
-4.9, 18.14, -5.400, 0.829, 7.4575, 12.075, 13.734, 16.51, -9.82, -4.9, 18.44, -0.808, 8.066,
6.914, 2.5098, 10.985, 16.96, 1.721, -1.0, 2.096, 9.2553, 8.635, 9.2136, 13.558, 7.7505, -0.55,
15.68, -7.3, 0.429, -0.560, 17.98, 19.068, 9.2764, 17.939, -6.51, -2.04, 7.29, -0.87, 10.311,
-6.74, -6.424, 18.708, -0.368, 9.725, 9.129, 6.99, 3.11, -1.573, -8.25, 10.427, 17.427, -9.739,
17.32, 6.076, -3.5, 7.43, -2.659, -0.89, -9.157, 1.9951, -3.463, 15.22, 13.99, 4.39, 18.12};
float correct[] = {0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 15.12, -7.59, -7.766, 0.000,
0.000, 0.429, 9.2764, 7.29, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000};
conv_param->input_h_ = 9;
conv_param->input_w_ = 13;
conv_param->input_channel_ = 1;
conv_param->output_h_ = 4;
conv_param->output_w_ = 5;
conv_param->stride_h_ = conv_param->stride_w_ = 4;
conv_param->pad_h_ = conv_param->pad_w_ = 2;
float out[20] = {0};
Conv1x1InputPackFp32(in, out, conv_param);
EXPECT_EQ(0, lite::CompareOutputData(out, correct, 20));
delete conv_param;
}
TEST_F(TestConv1x1Fp32, Input1x1PrePack2) {
auto conv_param = new ConvParameter();
float in[] = {
12.755477, 7.647509, 14.670943, -8.03628, -1.815172, 7.7517915, 5.6838546, 0.9693578, 10.86119, 10.960915,
17.758, -4.800611, -8.743361, 1.6797531, -0.234721, 7.7575417, 10.19116, 11.744166, -2.674233, 8.977257,
1.5364298, 14.600166, 16.625568, -4.820712, 10.050005, 4.114301, 10.436717, -7.443196, -2.669484, 5.3399734,
7.5060234, 12.705402, -2.203446, 19.582493, 8.716431, 11.463841, 2.1704009, -7.740846, 0.6420606, 15.4524,
1.9975507, -4.6742086, -0.425350, 7.120687, -9.663703, 18.799034, -4.425679, 10.846515, -1.993019, 0.2714671,
-8.511215, 16.797249, 18.438688, 8.391737, 15.632475, 16.98368, -5.901906, -2.718238, -3.131561, -3.707477,
-8.04332, 13.010143, 3.187699, 7.6656003, 9.344805, 2.100789, -7.123898, 10.088698, 7.8578715, -8.320831,
6.821173, -2.263130, -2.886815, 2.285673, 10.664816, -4.747543, -4.9607406, 1.0546302, 15.628643, 1.7381196,
18.267065, 11.504781, -0.193673, 16.431538, 8.011203, -3.3506372, 16.546675, -3.983052, 4.8116174, -9.49816,
11.714877, 12.401133, -3.799531, 5.109032, 11.657709, 1.9226302, 0.9720376, 14.517606, 7.712793, 17.820406,
17.644344, 15.314725, 17.884249, -3.6718662, -2.053803, 10.629432, 16.67133, -3.929358, 3.3747706, 8.818307,
-0.371532, 18.14205, 5.9272094, 12.691162, 6.816437, 8.310599, 17.566565, 16.581955, -7.433713, 2.5550082,
9.1433325, -2.9258926, 5.7442937, -2.9434314, -9.864248, -0.122141, 11.5717945, -4.174809, -6.192147, 8.390994,
-7.4617224, 17.419308, 7.0560303, 11.58972, 17.671894, 6.2352304, 13.778206, 3.4766717, -6.687946, -7.887233,
-1.150991, -3.1441534, 17.288366, 13.669407, -4.997481, -6.147624, -5.6006193, -8.15764, 9.595266, 8.296087,
-0.9590447, -3.6464965, -8.155689, 4.8459644, 19.75259, 5.5307946, -6.934994, -9.928046, 4.02548, -9.45412,
13.605555, 10.22008, -3.067481, 8.114803, 2.4563003, 0.4125615, 6.076172, -1.875376, 19.553644, -9.809106,
17.235031, -4.222316, -9.534478, 18.639902, 1.7095382, 18.821035, -8.177748, -2.9353676, 2.064462, 12.190292,
-1.475221, -1.842325, -3.664825, 10.538533, -4.255415, 3.4860964, 11.418711, -2.348281, -4.527373, 19.534836};
float correct[] = {12.755477, -8.03628, 5.6838546, 10.960915, 7.5060234, 19.582493, 2.1704009,
15.4524, -8.04332, 7.6656003, -7.123898, -8.320831, 11.714877, 5.109032,
0.9720376, 17.820406, 9.1433325, -2.9434314, 11.5717945, 8.390994, -0.9590447,
4.8459644, -6.934994, -9.45412, -1.4752215, 10.538533, 11.418711, 19.534836};
conv_param->input_h_ = 19;
conv_param->input_w_ = 10;
conv_param->input_channel_ = 1;
conv_param->output_h_ = 7;
conv_param->output_w_ = 4;
conv_param->stride_h_ = conv_param->stride_w_ = 3;
conv_param->pad_h_ = conv_param->pad_w_ = 0;
float out[28] = {0};
Conv1x1InputPackFp32(in, out, conv_param);
CompareOutputData(out, correct, 28, 0.0001);
delete conv_param;
}
TEST_F(TestConv1x1Fp32, Input1x1PrePack3) {
auto conv_param = new ConvParameter();
conv_param->input_channel_ = 2;
conv_param->input_h_ = conv_param->input_w_ = 3;
conv_param->output_h_ = conv_param->output_w_ = 3;
conv_param->stride_h_ = conv_param->stride_w_ = 2;
conv_param->pad_h_ = conv_param->pad_w_ = 1;
float in[] = {1.6767339, 12.25904, 19.018835, 3.0790641, -9.252135, -8.685675, 3.6115494, 3.2282279, 17.025112,
-5.052577, 12.750252, 12.701241, -8.9477215, -9.080522, 19.03931, -6.501229, -4.122992, 9.540845};
float out[18] = {0};
float correct[] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 17.025112,
-5.052577, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
Conv1x1InputPackFp32(in, out, conv_param);
EXPECT_EQ(0, lite::CompareOutputData(out, correct, 18));
delete conv_param;
}
TEST_F(TestConv1x1Fp32, Input1x1PrePack4) {
auto conv_param = new ConvParameter();
conv_param->input_channel_ = 6;
conv_param->input_h_ = conv_param->input_w_ = 3;
conv_param->output_h_ = conv_param->output_w_ = 3;
conv_param->stride_h_ = conv_param->stride_w_ = 2;
conv_param->pad_h_ = conv_param->pad_w_ = 1;
float in[] = {4.1795, 13.142, -3.593, 16.505, 19.899, 8.5562, 19.969, -6.235, -2.380, -9.027, 9.5542,
18.974, 23.622, 8.3608, 47.325, -14.36, 15.370, 4.3049, -0.784, 37.925, -0.081, 6.1298,
0.6721, -1.517, 37.998, 13.719, 11.029, 1.7127, -1.770, 41.903, 9.0560, 14.988, 3.1866,
0.0562, 8.1381, 9.1391, 14.530, -14.10, -8.115, -8.071, -8.158, 7.7566, 19.250, 17.923,
13.584, 3.3293, 9.7341, 18.834, -1.514, -0.293, 18.686, 0.0873, 4.2010, -2.253};
float correct[] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 37.998, 13.719, 11.029, 1.7127,
-1.770, 41.903, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
float out[54] = {0};
Conv1x1InputPackFp32(in, out, conv_param);
EXPECT_EQ(0, lite::CompareOutputData(out, correct, 54));
delete conv_param;
}
TEST_F(TestConv1x1Fp32, Conv1x1WeightTest1) {
ConvParameter *conv_param = new ConvParameter();
float in[] = {0.214637, 0.3815, 0.811557, 0.982146, 0.09123, 0.687198, 0.02742, 0.3360, 0.853275,
0.674123, 0.81337, 0.57188, 0.706416, 0.2740942, 0.9045, 0.07155, 0.130864, 0.037712,
0.5369175, 0.97283, 0.92133, 0.3588165, 0.7432479, 0.7886823, 0.870324, 0.230946, 0.343969,
0.095415, 0.50036, 0.396918, 0.09029, 0.934583, 0.91616, 0.206713, 0.9756054, 0.614025,
0.432057, 0.1493, 0.6787, 0.10642, 0.736823, 0.377668, 0.2464896, 0.93152, 0.315917,
0.35745, 0.52233, 0.0263, 0.339392, 0.99447, 0.49129, 0.675686, 0.75703, 0.6665356,
0.0491, 0.1070, 0.18899, 0.929156, 0.4633427, 0.08585, 0.040709, 0.2478724, 0.5238441,
0.0579918, 0.531636, 0.085524, 0.640923, 0.336395, 0.218651, 0.630491};
float co[] = {0.214637, 0.81337, 0.92133, 0.09029, 0.3815, 0.57188, 0.3588165, 0.934583, 0.811557,
0.706416, 0.7432479, 0.91616, 0.982146, 0.2740942, 0.7886823, 0.206713, 0.09123, 0.9045,
0.870324, 0.9756054, 0.687198, 0.07155, 0.230946, 0.614025, 0.02742, 0.130864, 0.343969,
0.432057, 0.3360, 0.037712, 0.095415, 0.1493, 0.853275, 0.5369175, 0.50036, 0.6787,
0.674123, 0.97283, 0.396918, 0.10642, 0, 0, 0, 0, 0,
0, 0, 0, 0.736823, 0.49129, 0.040709, 0, 0.377668, 0.675686,
0.2478724, 0, 0.2464896, 0.75703, 0.5238441, 0, 0.93152, 0.6665356, 0.0579918,
0, 0.315917, 0.0491, 0.531636, 0, 0.35745, 0.1070, 0.085524, 0,
0.52233, 0.18899, 0.640923, 0, 0.0263, 0.929156, 0.336395, 0, 0.339392,
0.4633427, 0.218651, 0, 0.99447, 0.08585, 0.630491, 0, 0, 0,
0, 0, 0, 0, 0, 0};
conv_param->input_channel_ = 10;
conv_param->output_channel_ = 7;
float out[96] = {0};
Pack1x1WeightFp32(in, out, conv_param);
EXPECT_EQ(0, lite::CompareOutputData(out, co, 96));
delete conv_param;
}
TEST_F(TestConv1x1Fp32, PostConvFuncC4Test1) {
float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806,
-0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815,
-6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584,
2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964,
-2.6300175, 0, 0, 0, -7.2690716, 0, 0, 0,
11.1863365, 0, 0, 0, -3.4595785, 0, 0, 0,
-8.344107, 0, 0, 0, -3.792715, 0, 0, 0,
-7.0394287, 0, 0, 0, -2.7693212, 0, 0, 0};
float bias[] = {0.7429814, 0.4863214, 0.9888875, 0.19727881, 0.009881007, 0, 0, 0};
float out[40] = {0};
float no[] = {-8.646674, -5.3524485, 8.56133, -1.2702886, -2.6201365, -4.7133026, 1.2270198, 17.954533,
11.086085, -7.2591906, -0.11849791, -3.9182835, 11.90631, 0.3088621, 11.196218, -4.530405,
-0.47735345, -3.7422307, -7.379536, -3.4496975, -5.419181, -0.14518678, -8.15199, 9.464027,
-8.334226, 14.387108, 8.693133, 8.080041, -0.30434704, -3.782834, 2.8319538, 7.177942,
-4.409286, 12.194644, -7.0295477, -8.511095, -5.110127, -4.992582, -0.31387085, -2.7594402};
PostConvFuncFp32C4(in, out, bias, 5, 8, 5, false, false);
CompareOutputData(out, no, 40, 0.0001);
float relu[] = {0, 0, 8.56133, 0, 0, 0, 1.2270198, 17.954533, 11.086085, 0,
0, 0, 11.90631, 0.3088621, 11.196218, 0, 0, 0, 0, 0,
0, 0, 0, 9.464027, 0, 14.387108, 8.693133, 8.080041, 0, 0,
2.8319538, 7.177942, 0, 12.194644, 0, 0, 0, 0, 0, 0};
PostConvFuncFp32C4(in, out, bias, 5, 8, 5, true, false);
CompareOutputData(out, relu, 40, 0.0001);
float corr_relu6[] = {0, 0, 6, 0, 0, 0, 1.2270198, 6, 6, 0, 0, 0, 6, 0.3088621, 6, 0, 0, 0, 0, 0,
0, 0, 0, 6, 0, 6, 6, 6, 0, 0, 2.8319538, 6, 0, 6, 0, 0, 0, 0, 0, 0};
PostConvFuncFp32C4(in, out, bias, 5, 8, 5, false, true);
CompareOutputData(out, corr_relu6, 40, 0.0001);
float nob_relu[] = {0, 0, 7.5724425, 0, 0, 0, 0.7406984, 16.965645,
10.888806, 0, 0, 0, 10.917422, 0.11158327, 11.1863365, 0,
0, 0, 0, 0, 0, 0, 0, 9.266748,
0, 13.644127, 8.206812, 7.091153, 0, 0, 2.0889723, 6.6916203,
0, 11.997365, 0, 0, 0, 0, 0, 0};
PostConvFuncFp32C4(in, out, nullptr, 5, 8, 5, true, false);
CompareOutputData(out, nob_relu, 40, 0.0001);
}
TEST_F(TestConv1x1Fp32, PostConvFuncC4Test2) {
float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -5.456284, 0.7406984, 16.965645, 10.888806,
-0.8614793, -4.404605, 10.917422, 0.11158327, -5.2733865, -0.96367484, -4.731118, -7.576815,
-6.1621623, -0.6315082, -9.140878, 9.266748, 13.644127, 8.206812, 7.091153, -0.50162584,
2.0889723, 6.6916203, -5.3981733, 11.997365, -9.254076, -5.5964484, -5.981469, -0.51114964,
-2.6300175, 0, 0, 0, -7.2690716, 0, 0, 0,
11.1863365, 0, 0, 0, -3.4595785, 0, 0, 0,
-8.344107, 0, 0, 0, -3.792715, 0, 0, 0,
-7.0394287, 0, 0, 0, -2.7693212, 0, 0, 0};
float bias[] = {0.7429814, 0.4863214, 0.9888875, 0.19727881, 0.009881007, 0, 0, 0};
float corr[] = {-8.646674, -5.3524485, 8.56133, -1.2702886, -2.6201365, -4.7133026, 1.2270198, 17.954533,
11.086085, -7.2591906, -0.11849791, -3.9182835, 11.90631, 0.3088621, 11.196218, -4.530405,
-0.47735345, -3.7422307, -7.379536, -3.4496975, -5.419181, -0.14518678, -8.15199, 9.464027,
-8.334226, 14.387108, 8.693133, 8.080041, -0.30434704, -3.782834, 2.8319538, 7.177942,
-4.409286, 12.194644, -7.0295477, -8.511095, -5.110127, -4.992582, -0.31387085, -2.7594402};
float out[40] = {0};
int thread_count_ = 2;
int thread_oc4_stride_ = 1;
int output_channel = 5;
int plane_size = 8;
for (int i = 0; i < thread_count_; i++) {
int cur_oc = MSMIN(thread_oc4_stride_ * 4, output_channel - i * thread_oc4_stride_ * 4);
if (cur_oc <= 0) break;
PostConvFuncFp32C4(in + thread_oc4_stride_ * i * 8 * 4, out + i * i * thread_oc4_stride_ * 4,
bias + i * thread_oc4_stride_ * 4, cur_oc, plane_size, output_channel, false, false);
}
CompareOutputData(out, corr, 40, 0.0001);
}
int Conv1x1TestInit1(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
ConvParameter *conv_param, float **correct) {
lite::tensor::Tensor *in_t =
new lite::tensor::Tensor(kNumberTypeFloat, {1, 2, 3, 4}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
in_t->MallocData();
float in[] = {12.216284, 3.3466918, 15.327419, 5.234958, 0.804376, 9.952188, 14.727955, -8.080715,
13.71383, 8.055829, 6.5845337, -9.25232, -4.24519, 11.550042, 9.262012, 1.2780352,
6.7263746, -3.9301445, 3.764492, -8.602078, -3.3558068, 13.619035, -2.6694393, 3.2008505};
memcpy(in_t->Data(), in, sizeof(float) * 24);
inputs_->push_back(in_t);
lite::tensor::Tensor *weight_t =
new lite::tensor::Tensor(kNumberTypeFloat, {3, 1, 1, 4}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
weight_t->MallocData();
float weight[] = {-0.7308652, 0.5257509, -0.87825793, -1.123181, -1.2206168, 0.562695,
1.5382664, -0.5020635, 0.8591602, -0.26410004, 1.1262615, 0.073132955}; /* nhwc */
memcpy(weight_t->Data(), weight, sizeof(float) * 12);
inputs_->push_back(weight_t);
lite::tensor::Tensor *bias_t =
new lite::tensor::Tensor(kNumberTypeFloat, {3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
bias_t->MallocData();
float bias[] = {2, 2, 2};
memcpy(bias_t->Data(), bias, sizeof(float) * 3);
inputs_->push_back(bias_t);
lite::tensor::Tensor *out_t =
new lite::tensor::Tensor(kNumberTypeFloat, {1, 2, 3, 3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
out_t->MallocData();
outputs_->push_back(out_t);
*correct = reinterpret_cast<float *>(malloc(out_t->ElementsNum() * sizeof(float)));
float co[] = {2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1.3731456, 1.6877825, 12.427691, 2., 2., 2.};
memcpy(*correct, co, out_t->ElementsNum() * sizeof(float));
conv_param->kernel_h_ = conv_param->kernel_w_ = 1;
conv_param->stride_h_ = conv_param->stride_w_ = 2;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
conv_param->pad_h_ = conv_param->pad_w_ = 1;
conv_param->is_relu_ = conv_param->is_relu6_ = false;
return out_t->ElementsNum();
}
TEST_F(TestConv1x1Fp32, Conv1x1Test1) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto conv_param = new ConvParameter();
lite::Context *ctx = new lite::Context();
ctx->threadNum = 1;
float *correct;
int total_size = Conv1x1TestInit1(&inputs_, &outputs_, conv_param, &correct);
kernel::Convolution1x1CPUKernel *conv1x1 =
new kernel::Convolution1x1CPUKernel(reinterpret_cast<OpParameter *>(conv_param), inputs_, outputs_, ctx);
conv1x1->Init();
conv1x1->Run();
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
delete conv_param;
delete conv1x1;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
}
int Conv1x1TestInit2(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
ConvParameter *conv_param, float **correct) {
size_t buffer_size;
lite::tensor::Tensor *in_t = new lite::tensor::Tensor(kNumberTypeFloat, {1, 300, 300, 24}, schema::Format_NHWC,
static_cast<schema::NodeType>(1));
in_t->MallocData();
std::string input_path = "./conv/conv1x1fp32_input1_nhwc.bin";
auto in = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &buffer_size));
memcpy(in_t->Data(), in, buffer_size);
inputs_->push_back(in_t);
lite::tensor::Tensor *weight_t =
new lite::tensor::Tensor(kNumberTypeFloat, {40, 1, 1, 24}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
weight_t->MallocData();
std::string weight_path = "./conv/conv1x1fp32_weight1_nhwc.bin";
auto weight = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &buffer_size));
memcpy(weight_t->Data(), weight, buffer_size);
inputs_->push_back(weight_t);
lite::tensor::Tensor *bias_t =
new lite::tensor::Tensor(kNumberTypeFloat, {40}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
bias_t->MallocData();
std::string bias_path = "./conv/conv1x1fp32_bias1_nhwc.bin";
auto bias = mindspore::lite::ReadFile(bias_path.c_str(), &buffer_size);
memcpy(bias_t->Data(), bias, buffer_size);
inputs_->push_back(bias_t);
lite::tensor::Tensor *out_t = new lite::tensor::Tensor(kNumberTypeFloat, {1, 300, 300, 40}, schema::Format_NHWC,
static_cast<schema::NodeType>(1));
out_t->MallocData();
outputs_->push_back(out_t);
std::string out_path = "./conv/conv1x1fp32_output1_nhwc.bin";
auto out_nhwc = mindspore::lite::ReadFile(out_path.c_str(), &buffer_size);
*correct = reinterpret_cast<float *>(malloc(buffer_size));
memcpy(*correct, out_nhwc, buffer_size);
conv_param->kernel_h_ = conv_param->kernel_w_ = 1;
conv_param->stride_h_ = conv_param->stride_w_ = 1;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
conv_param->pad_h_ = conv_param->pad_w_ = 0;
conv_param->is_relu_ = false;
conv_param->is_relu6_ = false;
return out_t->ElementsNum();
}
TEST_F(TestConv1x1Fp32, Conv1x1Test2) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto conv_param = new ConvParameter();
lite::Context *ctx = new lite::Context();
ctx->threadNum = 2;
float *correct;
int total_size = Conv1x1TestInit2(&inputs_, &outputs_, conv_param, &correct);
kernel::Convolution1x1CPUKernel *conv1x1 =
new kernel::Convolution1x1CPUKernel(reinterpret_cast<OpParameter *>(conv_param), inputs_, outputs_, ctx);
conv1x1->Init();
conv1x1->Run();
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
/* running warm up */
for (int i = 0; i < 0; i++) {
conv1x1->Run();
}
/* running time cost */
int loop_count = 1;
auto time_start = mindspore::lite::GetTimeUs();
for (int i = 0; i < loop_count; i++) {
conv1x1->Run();
}
auto time_end = mindspore::lite::GetTimeUs();
auto cost = time_end - time_start;
uint64_t time_avg = cost / loop_count;
printf("1x1 average time : %f ms\n", time_avg / 1000.0f);
delete conv_param;
delete conv1x1;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
}
} // namespace mindspore

View File

@ -0,0 +1,548 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "mindspore/lite/src/ops/ops.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/deconv.h"
namespace mindspore {
class TestDeConvolutionFp32 : public mindspore::Common {
public:
TestDeConvolutionFp32() {}
};
TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack1) {
float in[] = {0.43005997, -0.01335099, -0.43214464, -0.2569654, -0.14664753, -0.09249142, 0.42330834, 0.17678244,
-0.26901904, 0.29920393, -0.25139654, 0.04580693, 0.08898365, -0.29335496, 0.1332809, 0.06561925,
0.50099367, -0.45963442, -0.17191549, -0.1517635, -0.54385597, 0.20007996, 0.3174582, -0.13803318,
-0.10295965, 0.03531377, -0.05687982, 0.09801699, -0.1504936, 0.27094424, -0.15454058, 0.25500196,
0.03428256, 0.1711275, -0.28639716, 0.05972834, 0.1301975, 0.09662235, -0.26297596, 0.25723842,
0.37723106, -0.49640322, 0.21951586, -0.25885767, -0.44244745, 0.04153876, 0.41899854, 0.07920247,
0.31681255, 0.3300002, 0.23956111, 0.13012694, 0.26047292, 0.0851135, -0.185474, 0.306445,
0.20750166, -0.13887969, -0.15064844, -0.08100204, 0.08206631, 0.3151005, 0.26807567, -0.6340778,
0.1019667, 0.14200483, -0.56623703, 0.47877932, 0.13249867, 0.3862773, 0.7469436, 0.14524518,
0.42495733, 0.08011179, 0.19647601, -0.03030056, 0.12770538, -0.32460797, -0.2103409, 0.33223677,
-0.47110182, -0.5424416, 0.18340437, 0.3781465, 0.04931778, 0.17888185, 0.04547426, -0.01483545,
0.29989168, 0.12018301, 0.00213889, 0.21470474, -0.4031554, -0.10013647, -0.12780161, -0.28953925,
0.05002394, 0.5460746, -0.7209624, 0.32692385, -0.09215609, -0.07226299, 0.47478926, -0.6297518,
0.22869332, -0.33726704, -0.24732, 0.07623845, 0.38042688, -0.18950662, -0.16825019, 0.49407697,
-0.10242693, 0.59533256, -0.11732046, 0.7062394, 0.35063574, -0.17253993, -0.14738934, 0.26435736};
float co[] = {
0.43005997, -0.01335099, -0.43214464, -0.2569654, -0.10295965, 0.03531377, -0.05687982, 0.09801699, 0.31681255,
0.3300002, 0.23956111, 0.13012694, 0.42495733, 0.08011179, 0.19647601, -0.03030056, 0.05002394, 0.5460746,
-0.7209624, 0.32692385, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000,
0.000, 0.000, 0.000, 0.000, 0.000, 0.42330834, 0.17678244, -0.26901904, 0.29920393,
-0.15454058, 0.25500196, 0.03428256, 0.1711275, -0.185474, 0.306445, 0.20750166, -0.13887969, -0.2103409,
0.33223677, -0.47110182, -0.5424416, 0.47478926, -0.6297518, 0.22869332, -0.33726704, 0.000, 0.000,
0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000,
0.000, 0.08898365, -0.29335496, 0.1332809, 0.06561925, 0.1301975, 0.09662235, -0.26297596, 0.25723842,
0.08206631, 0.3151005, 0.26807567, -0.6340778, 0.04931778, 0.17888185, 0.04547426, -0.01483545, 0.38042688,
-0.18950662, -0.16825019, 0.49407697, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000,
0.000, 0.000, 0.000, 0.000, 0.000, 0.000, -0.17191549, -0.1517635, -0.54385597,
0.20007996, 0.21951586, -0.25885767, -0.44244745, 0.04153876, -0.56623703, 0.47877932, 0.13249867, 0.3862773,
0.00213889, 0.21470474, -0.4031554, -0.10013647, -0.11732046, 0.7062394, 0.35063574, -0.17253993, 0.000,
0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000,
0.000, 0.000, -0.14664753, -0.09249142, 0.000, 0.000, -0.1504936, 0.27094424, 0.000,
0.000, 0.26047292, 0.0851135, 0.000, 0.000, 0.12770538, -0.32460797, 0.000, 0.000,
-0.09215609, -0.07226299, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000,
0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, -0.25139654, 0.04580693,
0.000, 0.000, -0.28639716, 0.05972834, 0.000, 0.000, -0.15064844, -0.08100204, 0,
0, 0.18340437, 0.3781465, 0.000, 0.000, -0.24732, 0.07623845, 0.000, 0.000,
0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000,
0.000, 0.000, 0.000, 0.50099367, -0.45963442, 0.000, 0.000, 0.37723106, -0.49640322,
0.000, 0.000, 0.1019667, 0.14200483, 0.000, 0.000, 0.29989168, 0.12018301, 0.000,
0.000, -0.10242693, 0.59533256, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000,
0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.3174582,
-0.13803318, 0.000, 0.000, 0.41899854, 0.07920247, 0.000, 0.000, 0.7469436, 0.14524518,
0.000, 0.000, -0.12780161, -0.28953925, 0.000, 0.000, -0.14738934, 0.26435736, 0.000,
0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000,
0.000, 0.000, 0.000, 0.00};
float dst[256] = {0};
PackDeConvWeightFp32(in, dst, 5, 6, 2 * 2);
EXPECT_EQ(0, lite::CompareOutputData(dst, co, 256));
}
TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack2) {
float in[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36,
-0.784, 37.925, -0.081, 6.1298, 37.998, 13.719, 11.029, 1.7127, 9.0560, 14.988, 3.1866, 0.0562,
14.530, -14.10, -8.115, -8.071, 19.250, 17.923, 13.584, 3.3293, -1.514, -0.293, 18.686, 0.0873};
float co[] = {4.1795, 13.142, -3.593, 0, -2.380, -9.027, 23.622, 0, -0.784, 37.925, -0.081, 0, 11.029,
1.7127, 9.0560, 0, 14.530, -14.10, -8.115, 0, 13.584, 3.3293, -1.514, 0, 0, 0,
0, 0, 0, 0, 0, 0, 16.505, 19.969, -6.235, 0, 8.3608, 47.325, -14.36,
0, 6.1298, 37.998, 13.719, 0, 14.988, 3.1866, 0.0562, 0, -8.071, 19.250, 17.923, 0,
-0.293, 18.686, 0.0873, 0, 0, 0, 0, 0, 0, 0, 0, 0};
float dst[64] = {0};
PackDeConvWeightFp32(in, dst, 6, 3, 2 * 1);
EXPECT_EQ(0, lite::CompareOutputData(dst, co, 64));
}
TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test1) {
float in[] = {-9.389655, -5.83877, 7.5724425, -1.4675674, -2.6300175, 0, 0, 0,
-5.456284, 0.7406984, 16.965645, 10.888806, -7.2690716, 0, 0, 0,
-0.8614793, -4.404605, 10.917422, 0.11158327, 11.1863365, 0, 0, 0,
-5.2733865, -0.96367484, -4.731118, -7.576815, -3.4595785, 0, 0, 0,
-6.1621623, -0.6315082, -9.140878, 9.266748, -8.344107, 0, 0, 0,
13.644127, 8.206812, 7.091153, -0.50162584, -3.792715, 0, 0, 0,
2.0889723, 6.6916203, -5.3981733, 11.997365, -7.0394287, 0, 0, 0,
-9.254076, -5.5964484, -5.981469, -0.51114964, -2.7693212, 0, 0, 0};
float bias[] = {0.7429814, 0.4863214, 0.9888875, 0.19727881, 0.009881007, 0, 0, 0};
float out[40] = {0};
float no[] = {-8.646674, -5.3524485, 8.56133, -1.2702886, -2.6201365, -4.7133026, 1.2270198, 17.954533,
11.086085, -7.2591906, -0.11849791, -3.9182835, 11.90631, 0.3088621, 11.196218, -4.530405,
-0.47735345, -3.7422307, -7.379536, -3.4496975, -5.419181, -0.14518678, -8.15199, 9.464027,
-8.334226, 14.387108, 8.693133, 8.080041, -0.30434704, -3.782834, 2.8319538, 7.177942,
-4.409286, 12.194644, -7.0295477, -8.511095, -5.110127, -4.992582, -0.31387085, -2.7594402};
PostConvFuncFp32C8(in, out, bias, 5, 8, 5, false, false);
CompareOutputData(out, no, 40, 0.0001);
float relu[] = {0, 0, 8.56133, 0, 0, 0, 1.2270198, 17.954533, 11.086085, 0,
0, 0, 11.90631, 0.3088621, 11.196218, 0, 0, 0, 0, 0,
0, 0, 0, 9.464027, 0, 14.387108, 8.693133, 8.080041, 0, 0,
2.8319538, 7.177942, 0, 12.194644, 0, 0, 0, 0, 0, 0};
PostConvFuncFp32C8(in, out, bias, 5, 8, 5, true, false);
CompareOutputData(out, relu, 40, 0.0001);
float corr_relu6[] = {0, 0, 6, 0, 0, 0, 1.2270198, 6, 6, 0, 0, 0, 6, 0.3088621, 6, 0, 0, 0, 0, 0,
0, 0, 0, 6, 0, 6, 6, 6, 0, 0, 2.8319538, 6, 0, 6, 0, 0, 0, 0, 0, 0};
PostConvFuncFp32C8(in, out, bias, 5, 8, 5, false, true);
CompareOutputData(out, corr_relu6, 40, 0.0001);
float nob_relu[] = {0, 0, 7.5724425, 0, 0, 0, 0.7406984, 16.965645,
10.888806, 0, 0, 0, 10.917422, 0.11158327, 11.1863365, 0,
0, 0, 0, 0, 0, 0, 0, 9.266748,
0, 13.644127, 8.206812, 7.091153, 0, 0, 2.0889723, 6.6916203,
0, 11.997365, 0, 0, 0, 0, 0, 0};
PostConvFuncFp32C8(in, out, nullptr, 5, 8, 5, true, false);
CompareOutputData(out, nob_relu, 40, 0.0001);
}
int DeConvTestInit1(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
ConvParameter *conv_param, float **correct) {
std::vector<int> in_dims_nhwc = {1, 5, 7, 2};
lite::tensor::Tensor *in_t =
new lite::tensor::Tensor(kNumberTypeFloat, in_dims_nhwc, schema::Format_NHWC, static_cast<schema::NodeType>(1));
in_t->MallocData();
float in_nchw[] = {
0.39451003, 0.15045597, 0.5367726, 0.62690735, 0.113554195, 0.5402554, 0.5522764, 0.044319753, 0.25721782,
0.41789535, 0.6717553, 0.72254324, 0.15164013, 0.93585724, 0.33732107, 0.14599903, 0.20070823, 0.640386,
0.74077445, 0.088589266, 0.08755991, 0.4489046, 0.7409207, 0.7373529, 0.8887349, 0.045393247, 0.6483991,
0.7542141, 0.8730748, 0.5480396, 0.19493233, 0.41220096, 0.77443165, 0.9909433, 0.8081086, 0.91432786,
0.97605807, 0.48640794, 0.7690306, 0.9381521, 0.44073114, 0.27656683, 0.0725352, 0.53911537, 0.994353,
0.2642501, 0.29840338, 0.38820496, 0.37829784, 0.105839334, 0.07713295, 0.45629853, 0.9290373, 0.56323594,
0.59976774, 0.48325357, 0.102543674, 0.35449505, 0.3158472, 0.02927611, 0.44739273, 0.0516185, 0.12340133,
0.13908496, 0.54970616, 0.74672216, 0.673308, 0.6400629, 0.26790652, 0.98673576}; /* nhwc */
PackNCHWToNHWCFp32(in_nchw, in_t->Data(), in_t->Batch(), in_t->Width() * in_t->Height(), in_t->Channel());
inputs_->push_back(in_t);
std::vector<int> weight_dims_nhwc = {2, 3, 3, 6};
lite::tensor::Tensor *weight_t =
new lite::tensor::Tensor(kNumberTypeFloat, weight_dims_nhwc, schema::Format_NHWC, static_cast<schema::NodeType>(1));
weight_t->MallocData();
float weight_nchw[] = {
0.061163727, -0.06261389, 0.07708351, -0.019354159, -0.3859104, -0.082844816, -0.21268463, -0.15746808,
-0.096376516, 0.016681675, 0.1364329, -0.007941234, -0.10095563, 0.32489842, -0.042597733, 0.2701167,
-0.1415933, 0.007270595, -0.34188282, -0.3374504, -0.26375315, -0.075536035, 0.11136466, -0.2239981,
-0.07840504, -0.23905717, -0.10171707, -0.11058277, 0.363706, -0.09807812, -0.05729029, 0.0018888254,
-0.29443327, 0.13365538, 0.0453783, -0.31048688, 0.07062391, 0.16674924, 0.2268152, -0.18341774,
0.10190555, 0.08567296, 0.13261533, -0.40412605, 0.13981377, -0.08217087, -0.050615843, -0.05403921,
-0.028555218, 0.2651543, 0.10668221, -0.013095176, 0.09588115, 0.044287443, -0.009692867, 0.06717065,
-0.29928264, -0.09110823, -0.07987715, -0.15888898, 0.041994736, 0.086504236, -0.19046812, 0.20323305,
0.08014105, 0.009099235, 0.2525443, -0.010155359, 0.039532702, 0.20266832, 0.0045211455, -0.14146733,
-0.07135475, -0.011584315, 0.1640728, 0.13032198, 0.18829331, -0.27231383, -0.15681058, -0.14862166,
-0.084803745, -0.020582702, -0.0681792, 0.06789135, 0.13603394, 0.090862036, -0.08380498, -0.16875166,
-0.2570391, -0.013280135, 0.24033138, -0.08921211, 0.2722501, 0.24916205, -0.20001566, -0.11610521,
0.06060236, 0.10848369, -0.4512424, 0.023834296, 0.1643943, -0.25290534, 0.066953085, -0.11685201,
-0.4159784, 0.37839416, -0.11141268, -0.15986018}; /* nhwc */
PackNCHWToNHWCFp32(weight_nchw, weight_t->Data(), weight_t->Batch(), weight_t->Width() * weight_t->Height(),
weight_t->Channel());
inputs_->push_back(weight_t);
lite::tensor::Tensor *bias_t =
new lite::tensor::Tensor(kNumberTypeFloat, {6}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
bias_t->MallocData();
float bias[] = {-0.19064677, -0.0034778118, 0.63741624, -1.0311537, -1.0288948, 0.71384084};
memcpy(bias_t->Data(), bias, sizeof(float) * 6);
inputs_->push_back(bias_t);
std::vector<int> output_nhwc_dims = {1, 9, 13, 6};
lite::tensor::Tensor *out_t =
new lite::tensor::Tensor(kNumberTypeFloat, output_nhwc_dims, schema::Format_NHWC, static_cast<schema::NodeType>(1));
out_t->MallocData();
outputs_->push_back(out_t);
*correct = reinterpret_cast<float *>(malloc(out_t->ElementsNum() * sizeof(float)));
float nchw_co[] = {
-0.4159262, -0.46044537, -0.32667404, -0.4129007, -0.43664578, -0.39459872, -0.49400482, -0.4524444,
-0.30940545, -0.3997266, -0.4343413, -0.3413178, -0.42586732, -0.17157906, -0.4016143, -0.1097983,
-0.61039054, -0.19246969, -0.6629166, -0.24715163, -0.36829865, -0.1525711, -0.50477314, -0.22101344,
-0.4834266, -0.2868756, -0.21354413, -0.25993955, -0.33297282, -0.3962972, -0.43134302, -0.4203356,
-0.47099167, -0.32945585, -0.4933193, -0.3362223, -0.28017497, -0.31746963, -0.5820211, -0.2053628,
-0.23829184, -0.1884751, -0.36922038, -0.15235345, -0.6430171, -0.25126106, -0.63569427, -0.28716096,
-0.44492853, -0.14620401, -0.63435787, -0.27831206, -0.32927662, -0.24526191, -0.25315046, -0.2604547,
-0.30455, -0.37681228, -0.5119872, -0.4569657, -0.521509, -0.39786643, -0.27274203, -0.33900544,
-0.26303798, -0.25582826, -0.22533318, -0.2295449, -0.2498781, -0.20773302, -0.3777015, -0.2648021,
-0.50503045, -0.23136339, -0.45421264, -0.18984585, -0.23228307, -0.20156652, -0.3720746, -0.29076657,
-0.5048918, -0.35140067, -0.5004279, -0.32178527, -0.5359573, -0.3105652, -0.24390095, -0.28274524,
-0.44499388, -0.27840495, -0.49156278, -0.29778862, -0.34227157, -0.27404356, -0.5907216, -0.24148186,
-0.69942933, -0.3086446, -0.40131485, -0.16459012, -0.48982328, -0.33233505, -0.38212818, -0.2830558,
-0.5386851, -0.34576517, -0.4460499, -0.39519656, -0.3255192, -0.39476353, -0.40350133, -0.4050802,
-0.5406344, -0.40009072, -0.5944617, -0.42084867, -0.58132195, 0.11541255, 0.24717134, 0.035492875,
0.09734866, 0.16597912, 0.12381038, 0.1923936, 0.22568025, 0.023888497, 0.085535035, 0.16757454,
0.0050217994, 0.17314728, -0.043344263, 0.22266465, 0.057929777, 0.315026, 0.059421062, 0.3274499,
0.02406001, 0.18286264, 0.107178226, 0.17828721, -0.026181899, 0.23815396, 0.07757285, 0.010184985,
0.10768472, 0.07461695, 0.21580729, 0.12219772, 0.016947635, 0.21209088, -0.019231271, 0.22824496,
0.060270205, 0.041847467, 0.006466368, 0.29673898, 0.04507852, 0.18171927, -0.0113601275, 0.332155,
0.005798064, 0.29595143, 0.0644246, 0.349865, 0.04176835, 0.20181134, 0.036958598, 0.37659368,
-0.0836041, 0.105042435, -0.008922746, 0.04317373, 0.08832521, 0.057098098, 0.1759837, 0.19514789,
0.07342724, 0.23147877, 0.12975746, 0.019213844, 0.1296622, 0.020062651, 0.01870161, 0.1208442,
0.105693996, 0.20719647, 0.096077755, 0.3124894, 0.033647023, 0.26888633, -0.06377239, 0.09272936,
0.07928991, 0.06689171, 0.09909828, 0.14132921, -0.0038207127, 0.23364612, -0.015699724, 0.23287944,
-0.10473035, 0.28497344, 0.06822525, 0.0067269485, -0.0401484, 0.20666184, -0.074035384, 0.24031198,
0.06368647, 0.37245232, 0.012040168, 0.3706034, -0.020015769, 0.35215783, -0.018986963, 0.24762997,
0.14907081, 0.18981782, 0.061614163, 0.43125582, 0.07961907, 0.27877036, 0.048327066, 0.16899693,
0.16380924, 0.052272163, 0.14616457, 0.12360795, 0.08904207, 0.24163374, -0.043546468, 0.31575742,
0.1325127, 0.24905476, 0.8535125, 0.4158996, 0.8379569, 0.36076424, 0.7887811, 0.4375921,
0.85203487, 0.40125692, 0.8267099, 0.37313673, 0.78056836, 0.39070883, 0.750996, 0.39142087,
0.22870147, 0.36334175, 0.22776845, 0.2842683, 0.17623127, 0.14350969, -0.049721956, 0.22356126,
0.21368039, 0.38709402, 0.13516903, 0.14409906, 0.6560098, 0.65856576, 0.76757306, 0.5310113,
0.87118506, 0.25672826, 0.76198256, 0.39929584, 0.77406937, 0.43344593, 0.7274, 0.47634512,
0.8128686, 0.50098574, 0.39502823, 0.44564128, 0.24981359, 0.31671798, 0.15317863, 0.21069425,
0.13331234, 0.16383857, 0.28979823, 0.50662756, 0.46699578, 0.32232434, 0.6949107, 0.5320594,
0.668199, 0.6280134, 0.745686, 0.54090333, 0.88366413, 0.25842816, 0.8259659, 0.38957846,
0.7602142, 0.510612, 0.7381607, 0.38837627, 0.1904087, 0.33691993, 0.11685282, 0.26914072,
-0.06617683, 0.046009183, 0.0700444, 0.356119, 0.24937916, 0.30769932, 0.06569201, 0.28872308,
0.70671666, 0.4991707, 0.78667766, 0.36038262, 0.7790032, 0.32292485, 0.7419024, 0.48524532,
0.7267125, 0.46316653, 0.7193444, 0.4372312, 0.7446447, 0.2186315, 0.03533274, 0.216304,
0.25036755, 0.33977476, 0.3434924, 0.27370954, 0.16213486, 0.29132545, 0.078781545, 0.13724238,
-0.07549429, 0.1546486, 0.7608347, 0.43421644, 0.8019545, 0.44755372, 0.7997276, 0.44701982,
0.81010026, 0.3866497, 0.8441801, 0.24970922, 0.7982173, 0.4100442, 0.9132067, -0.94733083,
-1.0997784, -0.9421829, -1.1218354, -0.9859438, -1.1612623, -0.96009386, -1.1590697, -0.9456968,
-1.1142067, -0.9900875, -1.2211759, -1.004981, -1.2370956, -1.349351, -1.2184161, -1.1564747,
-1.0476248, -1.3034617, -0.9740715, -1.5131376, -1.0246942, -1.1564014, -1.091238, -1.2773981,
-0.76259595, -1.0244793, -0.9916798, -0.9816827, -1.0407434, -0.94001544, -1.2400658, -1.0058745,
-1.251888, -1.0026754, -1.2247806, -0.99559414, -1.1104892, -0.9950131, -0.93231726, -1.1461066,
-1.1102134, -1.2707901, -1.2258892, -1.2075629, -0.899022, -1.2902625, -0.8440441, -1.3612556,
-1.1327276, -1.0097463, -1.0870252, -1.0208998, -1.1372137, -1.0238695, -1.0300313, -0.9893144,
-1.0387962, -0.9455299, -1.2633826, -0.97857773, -1.2199508, -0.97649026, -1.0467783, -0.9870789,
-0.8867735, -1.2570912, -0.7990466, -1.2643247, -0.89268696, -1.3204725, -0.9196508, -1.3377675,
-1.1563053, -1.4048479, -0.9489901, -1.2825038, -0.8854966, -1.0209885, -1.166144, -0.99754405,
-1.278291, -1.0010624, -1.3216578, -1.0268149, -1.2370203, -0.99041694, -1.1121378, -1.0252388,
-1.2528121, -1.0185167, -0.72908103, -1.2807931, -0.9268043, -1.2740122, -1.0588918, -1.1783062,
-0.89433515, -1.4704434, -0.90606475, -1.1208334, -0.67285204, -1.341852, -0.80200857, -1.016867,
-1.2564906, -0.9801711, -1.1481711, -0.96293676, -1.0831497, -0.969197, -1.1662431, -0.9715335,
-1.3331397, -1.0049394, -1.2574395, -0.9399705, -1.171572, -0.88565385, -1.2087893, -1.1065894,
-1.0714839, -0.9627551, -1.1188276, -0.8515502, -1.2049681, -1.1173695, -1.0619929, -1.066168,
-1.0279324, -1.0882176, -1.129684, -0.9890163, -0.8740333, -1.2120758, -0.56714463, -1.1103767,
-0.86929953, -0.8791485, -0.98886544, -1.2087606, -0.76514137, -1.0997763, -1.0388865, -0.9463707,
-1.1105144, -0.89834666, -1.1851951, -1.1659127, -1.0132934, -1.0602008, -1.014949, -0.9327261,
-1.0910889, -1.1383713, -1.0091913, -0.99213076, -0.8544737, -1.056894, -0.94257253, -1.0971456,
-0.8758079, -1.2477993, -0.35445136, -1.2152452, -0.5471301, -1.086797, -0.73012817, -1.3945714,
-1.0156894, -1.0198442, -1.0294445, -0.9484633, -1.0997083, -0.95065546, -1.1494579, -1.0774312,
-1.0660617, -0.89763457, -1.13983, -0.9865928, -1.1166302, -1.0880268, -0.7381968, -0.9876064,
-0.5964719, -0.9657296, -0.74247324, -1.041322, -0.9059322, -1.2995027, -0.94108796, -0.8961159,
-1.0022087, -0.89709914, -1.0036592, -1.0499129, -1.0242954, -1.0631231, -1.0169288, -1.1581104,
-0.94418347, -0.853006, -1.1137545, -1.183017, -0.9731438, -1.086927, -0.97671837, -1.066008,
-0.48595423, -1.2475185, -0.50115275, -1.326726, -0.5102552, -1.3762127, -0.39939296, -0.9266701,
-0.6510342, -1.1439915, -0.2621194, -1.2735826, -0.9677428, -0.9337987, -1.0829964, -0.8954656,
-1.1583862, -1.0067348, -1.1215614, -1.05432, -1.0779985, -1.151866, -0.98149765, -0.8774674,
-1.1439066, 0.71160585, 0.43664122, 0.63968056, 0.3411116, 0.79933065, 0.6023572, 0.79020524,
0.5203902, 0.63432527, 0.34978527, 0.8055916, 0.5908885, 0.8279619, 0.6594803, 0.9234866,
0.6951297, 0.580612, 0.8534291, 0.61968267, 0.69770944, 0.8167807, 0.6326902, 0.6108708,
0.7726814, 0.5904738, 0.7508015, 0.71711653, 0.7171464, 0.71904653, 0.57166296, 0.70845544,
0.3433037, 0.8610815, 0.6749295, 0.87055725, 0.6884554, 0.70868635, 0.56713784, 0.91778255,
0.71033454, 0.8496836, 0.68372923, 0.9768204, 0.70797944, 0.5078603, 0.86912346, 0.48779017,
0.80497104, 0.66758573, 0.7792437, 0.63723993, 0.8364369, 0.7909154, 0.7067954, 0.74354,
0.72215, 0.7137401, 0.5893581, 0.77508205, 0.4122566, 0.8444451, 0.59620094, 0.6672466,
0.5036563, 0.6805886, 0.72852767, 0.63650995, 0.74002045, 0.6952553, 0.6968493, 0.8008863,
0.631564, 0.7486131, 0.79336673, 0.71474713, 0.6311797, 0.69647217, 0.6505069, 0.8208874,
0.7216524, 0.8688757, 0.6455133, 0.87244576, 0.6376998, 0.94607174, 0.8251329, 0.6735983,
0.51751864, 0.87973493, 0.74826664, 0.8994043, 0.72413105, 0.72747874, 0.808015, 0.6329842,
0.8622399, 0.47823763, 0.8856161, 0.6762785, 0.73437214, 0.3766058, 0.764144, 0.60693324,
0.89371794, 0.92908806, 0.7702812, 0.79492164, 0.58807003, 0.678272, 0.4573259, 0.7444603,
0.49847388, 0.84439206, 0.51984715, 0.9452883, 0.7511028, 0.81281227};
PackNCHWToNHWCFp32(nchw_co, *correct, out_t->Batch(), out_t->Width() * out_t->Height(), out_t->Channel());
conv_param->kernel_h_ = conv_param->kernel_w_ = 3;
conv_param->stride_h_ = conv_param->stride_w_ = 2;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
conv_param->pad_h_ = conv_param->pad_w_ = 1;
return out_t->ElementsNum();
}
TEST_F(TestDeConvolutionFp32, DeConvTest1) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
ConvParameter *deconv_param = new ConvParameter();
lite::Context *ctx = new lite::Context();
ctx->threadNum = 1;
float *correct;
int total_size = DeConvTestInit1(&inputs_, &outputs_, deconv_param, &correct);
kernel::DeConvolutionCPUKernel *deconv =
new kernel::DeConvolutionCPUKernel(reinterpret_cast<OpParameter *>(deconv_param), inputs_, outputs_, ctx);
deconv->Init();
deconv->Run();
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
delete deconv_param;
delete deconv;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
}
int DeConvTestInit2(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
ConvParameter *conv_param, float **correct) {
auto *in_t =
new lite::tensor::Tensor(kNumberTypeFloat, {1, 4, 2, 3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
in_t->MallocData();
float in[] = {7.7566547, 19.250782, 17.923292, 13.584222, 3.3293908, 9.734102, 18.83455, -1.5142503,
-0.29382008, 18.686155, 0.087307654, 4.2010098, -2.2539594, 4.1795673, 13.142356, -3.5939367,
16.505789, 19.899279, 8.556229, 19.969376, -6.2355065, -2.3804698, -9.027744, 9.5542}; /* nhwc */
memcpy(in_t->Data(), in, sizeof(float) * in_t->ElementsNum());
inputs_->push_back(in_t);
auto *weight_t =
new lite::tensor::Tensor(kNumberTypeFloat, {3, 3, 3, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
weight_t->MallocData();
float weight[] = {-0.39557076, 0.15087655, 0.35216075, -0.20893791, 0.28683448, 0.08006268, 0.9830812,
0.27212173, 0.5171944, -0.0014505, 0.78694165, 0.25425306, 0.16605458, -0.06127124,
0.07637237, -0.5596424, -0.26599348, 0.223331, -0.45220536, -0.17021523, 0.20895825,
-0.07697097, 0.17581257, 0.09553282, 0.5369023, -0.6631143, 0.51170826, -0.5332868,
-0.19414032, -0.7109704, -0.05779554, -0.05178713, 0.3592201, -0.05532698, 0.06928781,
-0.5730523, -0.21037689, -0.01435696, 0.33056936, 0.51348346, -0.28136733, -0.36971128,
-0.10048455, 0.09297352, -0.27097073, -0.08646037, -0.06631696, -0.1684566, 0.31797925,
-0.06270258, 0.00119315, -0.2821196, -0.5166795, -0.09961014}; /* nhwc */
memcpy(weight_t->Data(), weight, sizeof(float) * weight_t->ElementsNum());
inputs_->push_back(weight_t);
std::vector<int> out_nhwc_dims = {1, 7, 3, 2};
auto *out_t =
new lite::tensor::Tensor(kNumberTypeFloat, out_nhwc_dims, schema::Format_NHWC, static_cast<schema::NodeType>(1));
out_t->MallocData();
outputs_->push_back(out_t);
*correct = reinterpret_cast<float *>(malloc(out_t->ElementsNum() * sizeof(float))); /* nc4hw4 */
float nchw_co[] = {9.005795, 15.341887, 6.091704, 13.748293, -7.92756, 10.232557, 9.045886,
33.1299, 8.5707, 5.318199, -14.367487, 10.22495, -2.5882099, -0.12742424,
1.195263, 6.469591, 9.609164, 6.112072, 16.333368, -4.87735, -8.439645,
-11.827093, -12.340071, -2.6368382, -14.432123, -8.483799, -12.28651, 0.80561405,
11.332421, -0.43688506, -3.476327, -4.587028, -1.9491882, -3.3619316, -15.831648,
-10.517606, -9.204161, -0.15148449, 1.5822954, -10.122691, -4.7448387, 3.99177};
PackNCHWToNHWCFp32(nchw_co, *correct, out_t->Batch(), out_t->Width() * out_t->Height(), out_t->Channel());
conv_param->kernel_h_ = conv_param->kernel_w_ = 3;
conv_param->stride_h_ = conv_param->stride_w_ = 2;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
conv_param->pad_h_ = conv_param->pad_w_ = 1;
return out_t->ElementsNum();
}
TEST_F(TestDeConvolutionFp32, DeConvTest2) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto deconv_param = new ConvParameter();
float *correct;
int total_size = DeConvTestInit2(&inputs_, &outputs_, deconv_param, &correct);
lite::Context *ctx = new lite::Context;
ctx->threadNum = 4;
kernel::DeConvolutionCPUKernel *deconv =
new kernel::DeConvolutionCPUKernel(reinterpret_cast<OpParameter *>(deconv_param), inputs_, outputs_, ctx);
deconv->Init();
deconv->Run();
EXPECT_EQ(0, lite::CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size));
delete deconv_param;
delete deconv;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
}
int DeConvTestInit3(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
ConvParameter *conv_param, float **correct) {
std::vector<int> in_dims_nhwc = {1, 3, 3, 2};
auto *in_t =
new lite::tensor::Tensor(kNumberTypeFloat, in_dims_nhwc, schema::Format_NHWC, static_cast<schema::NodeType>(1));
in_t->MallocData();
float in_nchw[] = {0.10411751, 0.24034509, 0.71456534, 0.75286126, 0.9778457, 0.21043599,
0.26498786, 0.6701024, 0.9744634, 0.49075702, 0.03877404, 0.48646277,
0.5473929, 0.32438126, 0.87553847, 0.75820315, 0.86666644, 0.4852329};
PackNCHWToNHWCFp32(in_nchw, reinterpret_cast<float *>(in_t->Data()), in_t->Batch(), in_t->Width() * in_t->Height(),
in_t->Channel());
inputs_->push_back(in_t);
std::vector<int> w_dims_nhwc = {2, 2, 2, 2};
auto *weight_t =
new lite::tensor::Tensor(kNumberTypeFloat, w_dims_nhwc, schema::Format_NHWC, schema::NodeType_Parameter);
weight_t->MallocData();
float w_nchw[] = {-0.108016446, -0.44254777, 0.29249913, 0.18764605, 1.1250675, 0.29441583,
-0.34362152, 0.7557833, 0.16503833, 0.2418737, -0.26612744, 0.5072577,
-0.4284475, 0.2215941, 0.9273913, 0.34634787};
PackNCHWToNHWCFp32(w_nchw, weight_t->Data(), weight_t->Batch(), weight_t->Width() * weight_t->Height(),
weight_t->Channel());
inputs_->push_back(weight_t);
std::vector<int> out_dims_nhwc = {1, 9, 9, 2};
auto *out_t =
new lite::tensor::Tensor(kNumberTypeFloat, out_dims_nhwc, schema::Format_NC4HW4, schema::NodeType_Parameter);
out_t->MallocData();
outputs_->push_back(out_t);
*correct = reinterpret_cast<float *>(malloc(out_t->ElementsNum() * sizeof(float)));
float nchw_co[] = {0.069747314, 0.0, 0.072624244, -0.019562019, 0.0, -0.096985765, 0.0031001933, 0.0, -0.19856673,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
-0.100149624, 0.0, 0.26847753, 0.059981894, 0.0, 0.06476824, 0.07954865, 0.0, 0.38084733,
0.009019416, 0.0, -0.20077711, -0.05208808, 0.0, -0.35428414, 0.12176686, 0.0, 0.11864175,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.074535, 0.0, 0.4189407, 0.19969228, 0.0, 0.3480338, -0.17145246, 0.0, 0.4836111,
0.09650954, 0.0, 0.06611961, 0.0706511, 0.0, -0.08692852, -0.02517605, 0.0, -0.31388155,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
-0.12426994, 0.0, 0.43432832, -0.034639344, 0.0, 0.5653653, 0.15589589, 0.0, 0.42899233,
-0.0931244, 0.0, 0.1394027, 0.2537918, 0.0, 0.0793535, 0.5955104, 0.0, 0.31817663,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.41934675, 0.0, 0.24866292, -0.04662904, 0.0, 0.1950781, 0.2056013, 0.0, 0.7085419,
0.6124906, 0.0, 0.34295332, 0.96116215, 0.0, 0.35977423, -0.1383676, 0.0, 0.25596985,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.24894807, 0.0, 0.7585884, -0.03518048, 0.0, 0.8513882, 0.73965645, 0.0, 0.46228492,
-0.026721025, 0.0, 0.24602996, 0.38258934, 0.0, 0.38933694, 0.88844025, 0.0, 0.3944222,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.6120955, 0.0, 0.46287543, 0.57347727, 0.0, 0.80662024, 0.11515418, 0.0, 0.90454257};
PackNCHWToNHWCFp32(nchw_co, *correct, out_t->Batch(), out_t->Width() * out_t->Height(), out_t->Channel());
conv_param->kernel_h_ = conv_param->kernel_w_ = 2;
conv_param->stride_h_ = conv_param->stride_w_ = 3;
conv_param->dilation_h_ = conv_param->dilation_w_ = 2;
conv_param->pad_h_ = conv_param->pad_w_ = 0;
return out_t->ElementsNum();
}
TEST_F(TestDeConvolutionFp32, DeConvTest3) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto deconv_param = new ConvParameter();
float *correct;
int total_size = DeConvTestInit3(&inputs_, &outputs_, deconv_param, &correct);
lite::Context *ctx = new lite::Context;
ctx->threadNum = 2;
kernel::DeConvolutionCPUKernel *deconv =
new kernel::DeConvolutionCPUKernel(reinterpret_cast<OpParameter *>(deconv_param), inputs_, outputs_, ctx);
deconv->Init();
deconv->Run();
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
delete deconv_param;
delete deconv;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
}
int DeConvTestInit4(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
ConvParameter *conv_param, float **correct) {
size_t buffer_size;
std::vector<int> in_nhwc_dims = {1, 300, 300, 30};
auto *in_t =
new lite::tensor::Tensor(kNumberTypeFloat, in_nhwc_dims, schema::Format_NHWC, static_cast<schema::NodeType>(1));
in_t->MallocData();
std::string in_nhwc_path = "./deconv/deconv_fp32_nhwc_input1.bin";
auto in_nhwc = reinterpret_cast<float *>(mindspore::lite::ReadFile(in_nhwc_path.c_str(), &buffer_size));
memcpy(in_t->Data(), in_nhwc, buffer_size);
inputs_->push_back(in_t);
std::vector<int> w_nhwc_dims = {30, 3, 3, 40};
auto *weight_t =
new lite::tensor::Tensor(kNumberTypeFloat, w_nhwc_dims, schema::Format_NHWC, static_cast<schema::NodeType>(1));
weight_t->MallocData();
std::string weight_path = "./deconv/deconv_fp32_nchw_weight1.bin";
auto weight_nchw = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &buffer_size));
PackNCHWToNHWCFp32(weight_nchw, weight_t->Data(), weight_t->Batch(), weight_t->Width() * weight_t->Height(),
weight_t->Channel());
inputs_->push_back(weight_t);
auto *bias_t =
new lite::tensor::Tensor(kNumberTypeFloat, {40}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
bias_t->MallocData();
std::string bias_path = "./deconv/deconv_fp32_nchw_bias1.bin";
auto bias = mindspore::lite::ReadFile(bias_path.c_str(), &buffer_size);
memcpy(bias_t->Data(), bias, buffer_size);
inputs_->push_back(bias_t);
std::vector<int> out_nhwc_dims = {1, 302, 302, 40};
auto *out_t =
new lite::tensor::Tensor(kNumberTypeFloat, out_nhwc_dims, schema::Format_NHWC, static_cast<schema::NodeType>(1));
out_t->MallocData();
outputs_->push_back(out_t);
std::string out_path = "./deconv/deconv_fp32_nchw_output1.bin";
auto out_nchw = mindspore::lite::ReadFile(out_path.c_str(), &buffer_size);
*correct = reinterpret_cast<float *>(malloc(buffer_size));
PackNCHWToNHWCFp32(out_nchw, *correct, out_t->Batch(), out_t->Width() * out_t->Height(), out_t->Channel());
conv_param->kernel_h_ = conv_param->kernel_w_ = 3;
conv_param->stride_h_ = conv_param->stride_w_ = 1;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
conv_param->pad_h_ = conv_param->pad_w_ = 0;
conv_param->is_relu_ = conv_param->is_relu6_ = false;
return out_t->ElementsNum();
}
TEST_F(TestDeConvolutionFp32, DeConvTest4) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto deconv_param = new ConvParameter();
float *correct;
int total_size = DeConvTestInit4(&inputs_, &outputs_, deconv_param, &correct);
lite::Context *ctx = new lite::Context;
ctx->threadNum = 2;
kernel::DeConvolutionCPUKernel *deconv =
new kernel::DeConvolutionCPUKernel(reinterpret_cast<OpParameter *>(deconv_param), inputs_, outputs_, ctx);
deconv->Init();
deconv->Run();
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
/* running warm up */
for (int i = 0; i < 0; i++) {
deconv->Run();
}
/* running time cost */
int loop_count = 1;
auto time_start = mindspore::lite::GetTimeUs();
for (int i = 0; i < loop_count; i++) {
deconv->Run();
}
auto time_end = mindspore::lite::GetTimeUs();
auto cost = time_end - time_start;
uint64_t time_avg = cost / loop_count;
printf("deconv fp32 average time : %f ms\n", time_avg / 1000.0f);
delete deconv_param;
delete deconv;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
}
} // namespace mindspore

View File

@ -0,0 +1,145 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "src/runtime/kernel/arm/fp32/fullconnection.h"
#include "src/runtime/kernel/arm/opclib/fp32/matmul.h"
namespace mindspore {
using mindspore::lite::tensor::Tensor;
class TestFcFp32 : public mindspore::Common {
public:
TestFcFp32() {}
};
int FcTestInit1(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
MatMulParameter *matmal_param, float **correct) {
Tensor *in_t = new Tensor(kNumberTypeFloat, {2, 2, 2, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
in_t->MallocData();
float in[] = {-3.2366564, -4.7733846, -7.8329225, 16.146885, 5.060793, -6.1471, -1.7680453, -6.5721383,
17.87506, -5.1192183, 10.742863, 1.4536934, 19.693445, 19.45783, 5.063163, 0.5234792};
memcpy(in_t->Data(), in, sizeof(float) * in_t->ElementsNum());
inputs_->push_back(in_t);
Tensor *weight_t = new Tensor(kNumberTypeFloat, {3, 8}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
weight_t->MallocData();
float weight[] = {-0.0024438887, 0.0006738146, -0.008169129, 0.0021510671, -0.012470592, -0.0053063435,
0.006050155, 0.008656233, 0.012911413, -0.0028635843, -0.00034080597, -0.0010622552,
-0.012254699, -0.01312836, 0.0025241964, -0.004706142, 0.002451482, -0.009558459,
0.004481974, 0.0033251503, -0.011705584, -0.001720293, -0.0039410214, -0.0073637343};
memcpy(weight_t->Data(), weight, sizeof(float) * weight_t->ElementsNum());
inputs_->push_back(weight_t);
Tensor *bias_t = new Tensor(kNumberTypeFloat, {3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
bias_t->MallocData();
float bias[] = {1.6103756, -0.9872417, 0.546849};
memcpy(bias_t->Data(), bias, sizeof(float) * bias_t->ElementsNum());
inputs_->push_back(bias_t);
Tensor *out_t = new Tensor(kNumberTypeFloat, {2, 3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
out_t->MallocData();
outputs_->push_back(out_t);
*correct = reinterpret_cast<float *>(malloc(out_t->ElementsNum() * sizeof(float)));
float nchw_co[] = {1.6157111, -0.98469573, 0.6098231, 1.1649342, -1.2334653, 0.404779};
memcpy(*correct, nchw_co, out_t->ElementsNum() * sizeof(float));
matmal_param->b_transpose_ = true;
matmal_param->a_transpose_ = false;
matmal_param->has_bias_ = true;
matmal_param->act_type_ = ActType_No;
return out_t->ElementsNum();
}
TEST_F(TestFcFp32, FcTest1) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto matmul_param = new MatMulParameter();
float *correct;
int total_size = FcTestInit1(&inputs_, &outputs_, matmul_param, &correct);
lite::Context *ctx = new lite::Context;
ctx->threadNum = 2;
kernel::FullconnectionCPUKernel *fc =
new kernel::FullconnectionCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
fc->Init();
fc->Run();
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
}
int FcTestInit2(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
MatMulParameter *matmal_param, float **correct) {
size_t buffer_size;
Tensor *in_t = new Tensor(kNumberTypeFloat, {20, 4, 2, 10}, schema::Format_NCHW, static_cast<schema::NodeType>(1));
in_t->MallocData();
std::string in_path = "./matmul/FcFp32_input1.bin";
auto in_data = mindspore::lite::ReadFile(in_path.c_str(), &buffer_size);
memcpy(in_t->Data(), in_data, buffer_size);
inputs_->push_back(in_t);
Tensor *weight_t = new Tensor(kNumberTypeFloat, {30, 80}, schema::Format_NCHW, static_cast<schema::NodeType>(1));
weight_t->MallocData();
std::string weight_path = "./matmul/FcFp32_weight1.bin";
auto w_data = mindspore::lite::ReadFile(weight_path.c_str(), &buffer_size);
memcpy(weight_t->Data(), w_data, buffer_size);
inputs_->push_back(weight_t);
Tensor *bias_t = new Tensor(kNumberTypeFloat, {30}, schema::Format_NCHW, static_cast<schema::NodeType>(1));
bias_t->MallocData();
std::string bias_path = "./matmul/FcFp32_bias1.bin";
auto bias_data = mindspore::lite::ReadFile(bias_path.c_str(), &buffer_size);
memcpy(bias_t->Data(), bias_data, buffer_size);
inputs_->push_back(bias_t);
Tensor *out_t = new Tensor(kNumberTypeFloat, {20, 30}, schema::Format_NCHW, static_cast<schema::NodeType>(1));
out_t->MallocData();
outputs_->push_back(out_t);
*correct = reinterpret_cast<float *>(malloc(out_t->ElementsNum() * sizeof(float)));
std::string out_path = "./matmul/FcFp32_output1.bin";
auto out_data = mindspore::lite::ReadFile(out_path.c_str(), &buffer_size);
memcpy(*correct, out_data, out_t->ElementsNum() * sizeof(float));
matmal_param->b_transpose_ = true;
matmal_param->a_transpose_ = false;
matmal_param->has_bias_ = true;
matmal_param->act_type_ = ActType_No;
return out_t->ElementsNum();
}
TEST_F(TestFcFp32, FcTest2) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto matmul_param = new MatMulParameter();
float *correct;
int total_size = FcTestInit2(&inputs_, &outputs_, matmul_param, &correct);
lite::Context *ctx = new lite::Context;
ctx->threadNum = 1;
kernel::FullconnectionCPUKernel *fc =
new kernel::FullconnectionCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
fc->Init();
fc->Run();
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
}
} // namespace mindspore

View File

@ -0,0 +1,369 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/pack.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strassen_matmul.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/conv_parameter.h"
namespace mindspore {
class TestStrassenFp32 : public mindspore::Common {
public:
TestStrassenFp32() {}
};
TEST_F(TestStrassenFp32, MatrixAdd1) {
float a[] = {0.06796285, 0.6176181, 0.33195993, 0.2752791, 0.36864007, 0.04605605, 0.33899087, 0.9820137,
0.49804246, 0.8242412, 0.8458231, 0.6530539, 0.6336898, 0.8367749, 0.57166654, 0.25895607,
0.90079665, 0.10585558, 0.8215811, 0.48977906, 0.7895138, 0.41816455, 0.18999523, 0.28736928,
0.5882977, 0.44262612, 0.65245426, 0.7834421, 0.60903394, 0.82289135, 0.03855767, 0.30543327,
0.37747085, 0, 0, 0, 0.590335, 0, 0, 0,
0.7578682, 0, 0, 0, 0.81001425, 0, 0, 0,
0.9487712, 0, 0, 0, 0.11742989, 0, 0, 0,
0.60004807, 0, 0, 0, 0.05973052, 0, 0, 0};
float b[] = {0.112120815, 0.6869974, 0.08290442, 0.43003577, 0.044390075, 0.23077105, 0.23964432, 0.4426781,
0.6612115, 0.14988606, 0.84881437, 0.032587975, 0.35028255, 0.41838303, 0.12859282, 0.060378596,
0.8272769, 0.6949804, 0.9120368, 0.12399232, 0.9292184, 0.7566025, 0.10235854, 0.015936268,
0.20426726, 0.9926392, 0.54714125, 0.7022856, 0.58746314, 0.95714045, 0.26433542, 0.9030878,
0.8596953, 0, 0, 0, 0.8341476, 0, 0, 0,
0.72301114, 0, 0, 0, 0.40733734, 0, 0, 0,
0.2873559, 0, 0, 0, 0.612321, 0, 0, 0,
0.5008707, 0, 0, 0, 0.2586266, 0, 0, 0};
float add[] = {0.18008366, 1.3046155, 0.41486436, 0.7053149, 0.41303015, 0.2768271, 0.5786352, 1.4246918,
1.159254, 0.9741273, 1.6946375, 0.6856419, 0.9839724, 1.255158, 0.7002593, 0.3193347,
1.7280736, 0.80083597, 1.7336179, 0.6137714, 1.7187322, 1.174767, 0.29235378, 0.30330554,
0.792565, 1.4352653, 1.1995955, 1.4857277, 1.1964971, 1.7800318, 0.3028931, 1.2085211,
1.2371662, 0, 0, 0, 1.4244826, 0, 0, 0,
1.4808793, 0, 0, 0, 1.2173516, 0, 0, 0,
1.2361271, 0, 0, 0, 0.72975093, 0, 0, 0,
1.1009188, 0, 0, 0, 0.31835714, 0, 0, 0};
float out[64] = {0};
MatrixAdd(a, b, out, 32, 32, 32, 8, 2);
EXPECT_EQ(0, lite::CompareOutputData(out, add, 64));
}
TEST_F(TestStrassenFp32, MatrixAdd2) {
float a[] = {0.06796285, 0.6176181, 0.33195993, 0.2752791, 0.36864007, 0.04605605, 0.33899087, 0.9820137,
0.49804246, 0.8242412, 0.8458231, 0.6530539, 0.6336898, 0.8367749, 0.57166654, 0.25895607,
0.90079665, 0.10585558, 0.8215811, 0.48977906, 0.7895138, 0.41816455, 0.18999523, 0.28736928,
0.5882977, 0.44262612, 0.65245426, 0.7834421, 0.60903394, 0.82289135, 0.03855767, 0.30543327,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0.37747085, 0, 0, 0,
0.590335, 0, 0, 0, 0.7578682, 0, 0, 0,
0.81001425, 0, 0, 0, 0.9487712, 0, 0, 0,
0.11742989, 0, 0, 0, 0.60004807, 0, 0, 0,
0.05973052, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
float b[] = {0.112120815, 0.6869974, 0.08290442, 0.43003577, 0.044390075, 0.23077105, 0.23964432, 0.4426781,
0.6612115, 0.14988606, 0.84881437, 0.032587975, 0.35028255, 0.41838303, 0.12859282, 0.060378596,
0.8272769, 0.6949804, 0.9120368, 0.12399232, 0.9292184, 0.7566025, 0.10235854, 0.015936268,
0.20426726, 0.9926392, 0.54714125, 0.7022856, 0.58746314, 0.95714045, 0.26433542, 0.9030878,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0.8596953, 0, 0, 0, 0.8341476, 0, 0, 0,
0.72301114, 0, 0, 0, 0.40733734, 0, 0, 0,
0.2873559, 0, 0, 0, 0.612321, 0, 0, 0,
0.5008707, 0, 0, 0, 0.2586266, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
float add[] = {0.18008366, 1.3046155, 0.41486436, 0.7053149, 0.41303015, 0.2768271, 0.5786352, 1.4246918,
1.159254, 0.9741273, 1.6946375, 0.6856419, 0.9839724, 1.255158, 0.7002593, 0.3193347,
1.7280736, 0.80083597, 1.7336179, 0.6137714, 1.7187322, 1.174767, 0.29235378, 0.30330554,
0.792565, 1.4352653, 1.1995955, 1.4857277, 1.1964971, 1.7800318, 0.3028931, 1.2085211,
0, 0, 0, 0, 1.2371662, 0, 0, 0,
1.4244826, 0, 0, 0, 1.4808793, 0, 0, 0,
1.2173516, 0, 0, 0, 1.2361271, 0, 0, 0,
0.72975093, 0, 0, 0, 1.1009188, 0, 0, 0,
0.31835714, 0, 0, 0, 0, 0, 0, 0};
float out[72] = {0};
MatrixAdd(a, b, out, 44, 56, 36, 8, 2);
EXPECT_EQ(0, lite::CompareOutputData(out, add, 72));
}
TEST_F(TestStrassenFp32, MatrixSub1) {
float a[] = {0.4160896, 0.55011475, 0.60395557, 0.964036, 0.8010256, 0.908257, 0.60170764, 0.008877548,
0.4973592, 0.6104505, 0.2957374, 0.39589414, 0.0151615525, 0.45663023, 0.3815148, 0.6419536,
0.9118046, 0.5312479, 0.104496025, 0.5972911, 0.9671534, 0.7195669, 0.23360363, 0.22078007,
0.31118092, 0.7438336, 0.5592656, 0.7212792, 0.97856164, 0.26012093, 0.18205991, 0.90656054,
0.24593723, 0, 0, 0, 0.5024593, 0, 0, 0,
0.42271087, 0, 0, 0, 0.48668534, 0, 0, 0,
0.4374295, 0, 0, 0, 0.22822042, 0, 0, 0,
0.88180095, 0, 0, 0, 0.7505223, 0, 0, 0};
float b[] = {0.14911577, 0.63214976, 0.74834836, 0.36854064, 0.5801671, 0.24166176, 0.64528674, 0.04887214,
0.23637155, 0.34321627, 0.69035923, 0.6114065, 0.73006815, 0.575073, 0.88130534, 0.72951907,
0.17092401, 0.652334, 0.6288812, 0.62121505, 0.12793411, 0.16503152, 0.7564361, 0.51976234,
0.19353953, 0.5795124, 0.6671185, 0.10646773, 0.13608798, 0.37959677, 0.24294423, 0.1790138,
0.85054415, 0, 0, 0, 0.18541782, 0, 0, 0,
0.72714496, 0, 0, 0, 0.43221787, 0, 0, 0,
0.7200413, 0, 0, 0, 0.15780604, 0, 0, 0,
0.30473796, 0, 0, 0, 0.37719592, 0, 0, 0};
float s[] = {0.26697382, -0.082035, -0.14439279, 0.59549534, 0.22085851, 0.6665952, -0.0435791, -0.03999459,
0.26098764, 0.26723424, -0.39462185, -0.21551237, -0.7149066, -0.11844277, -0.49979055, -0.08756548,
0.7408806, -0.12108606, -0.5243852, -0.02392393, 0.8392193, 0.5545354, -0.5228325, -0.29898226,
0.11764139, 0.16432118, -0.10785288, 0.6148115, 0.8424736, -0.11947584, -0.06088431, 0.72754675,
-0.6046069, 0., 0., 0., 0.31704146, 0., 0., 0.,
-0.3044341, 0., 0., 0., 0.05446747, 0., 0., 0.,
-0.2826118, 0., 0., 0., 0.07041438, 0., 0., 0.,
0.57706296, 0., 0., 0., 0.3733264, 0., 0., 0.};
float out[64] = {0};
MatrixSub(a, b, out, 32, 32, 32, 8, 2);
EXPECT_EQ(0, lite::CompareOutputData(out, s, 64));
}
TEST_F(TestStrassenFp32, MatrixSub2) {
float a[] = {0.4160896, 0.55011475, 0.60395557, 0.964036, 0.8010256, 0.908257, 0.60170764, 0.008877548,
0.4973592, 0.6104505, 0.2957374, 0.39589414, 0.0151615525, 0.45663023, 0.3815148, 0.6419536,
0.9118046, 0.5312479, 0.104496025, 0.5972911, 0.9671534, 0.7195669, 0.23360363, 0.22078007,
0.31118092, 0.7438336, 0.5592656, 0.7212792, 0.97856164, 0.26012093, 0.18205991, 0.90656054,
0.24593723, 0, 0, 0, 0.5024593, 0, 0, 0,
0.42271087, 0, 0, 0, 0.48668534, 0, 0, 0,
0.4374295, 0, 0, 0, 0.22822042, 0, 0, 0,
0.88180095, 0, 0, 0, 0.7505223, 0, 0, 0};
float b[] = {0.14911577, 0.63214976, 0.74834836, 0.36854064, 0.5801671, 0.24166176, 0.64528674, 0.04887214,
0.23637155, 0.34321627, 0.69035923, 0.6114065, 0.73006815, 0.575073, 0.88130534, 0.72951907,
0.17092401, 0.652334, 0.6288812, 0.62121505, 0.12793411, 0.16503152, 0.7564361, 0.51976234,
0.19353953, 0.5795124, 0.6671185, 0.10646773, 0.13608798, 0.37959677, 0.24294423, 0.1790138,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0.85054415, 0, 0, 0,
0.18541782, 0, 0, 0, 0.72714496, 0, 0, 0,
0.43221787, 0, 0, 0, 0.7200413, 0, 0, 0,
0.15780604, 0, 0, 0, 0.30473796, 0, 0, 0,
0.37719592, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
float s[] = {0.26697382, -0.082035, -0.14439279, 0.59549534, 0.22085851, 0.6665952, -0.0435791, -0.03999459,
0.26098764, 0.26723424, -0.39462185, -0.21551237, -0.7149066, -0.11844277, -0.49979055, -0.08756548,
0.7408806, -0.12108606, -0.5243852, -0.02392393, 0.8392193, 0.5545354, -0.5228325, -0.29898226,
0.11764139, 0.16432118, -0.10785288, 0.6148115, 0.8424736, -0.11947584, -0.06088431, 0.72754675,
0, 0, 0, 0, -0.6046069, 0., 0., 0.,
0.31704146, 0., 0., 0., -0.3044341, 0., 0., 0.,
0.05446747, 0., 0., 0., -0.2826118, 0., 0., 0.,
0.07041438, 0., 0., 0., 0.57706296, 0., 0., 0.,
0.3733264, 0., 0., 0, 0, 0, 0, 0.};
float out[72] = {0};
MatrixSub(a, b, out, 32, 44, 36, 8, 2);
EXPECT_EQ(0, lite::CompareOutputData(out, s, 72));
}
TEST_F(TestStrassenFp32, MatrixPack1) {
float in[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36,
-0.784, 37.925, -0.081, 6.1298, 37.998, 13.719, 11.029, 1.7127, 9.0560, 14.988, 3.1866, 0.0562,
14.530, -14.10, -8.115, -8.071, 19.250, 17.923, 13.584, 3.3293, -1.514, -0.293, 18.686, 0.0873,
19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0, 15.370, 4.3049, 0.0, 0.0,
0.6721, -1.517, 0.0, 0.0, -1.770, 41.903, 0.0, 0.0, 8.1381, 9.1391, 0.0, 0.0,
-8.158, 7.7566, 0.0, 0.0, 9.7341, 18.834, 0.0, 0.0, 4.2010, -2.253, 0.0, 0.0};
float correct[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36,
-0.784, 37.925, -0.081, 6.1298, 19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0,
15.370, 4.3049, 0.0, 0.0, 0.6721, -1.517, 0.0, 0.0, 37.998, 13.719, 11.029, 1.7127,
9.0560, 14.988, 3.1866, 0.0562, 14.530, -14.10, -8.115, -8.071, -1.770, 41.903, 0.0, 0.0,
8.1381, 9.1391, 0.0, 0.0, -8.158, 7.7566, 0.0, 0.0};
float out[56] = {0};
MatrixPack(in, out, 7, 2, 36);
EXPECT_EQ(0, lite::CompareOutputData(out, correct, 56));
}
TEST_F(TestStrassenFp32, MatrixPack2) {
float in[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36,
-0.784, 37.925, -0.081, 6.1298, 37.998, 13.719, 11.029, 1.7127, 9.0560, 14.988, 3.1866, 0.0562,
14.530, -14.10, -8.115, -8.071, 19.250, 17.923, 13.584, 3.3293, -1.514, -0.293, 18.686, 0.0873,
19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0, 15.370, 4.3049, 0.0, 0.0,
0.6721, -1.517, 0.0, 0.0, -1.770, 41.903, 0.0, 0.0, 8.1381, 9.1391, 0.0, 0.0,
-8.158, 7.7566, 0.0, 0.0, 9.7341, 18.834, 0.0, 0.0, 4.2010, -2.253, 0.0, 0.0};
float correct[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36,
-0.784, 37.925, -0.081, 6.1298, 19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0,
15.370, 4.3049, 0.0, 0.0, 0.6721, -1.517, 0.0, 0.0, 37.998, 13.719, 11.029, 1.7127,
9.0560, 14.988, 3.1866, 0.0562, 14.530, -14.10, -8.115, -8.071, 19.250, 17.923, 13.584, 3.3293,
-1.770, 41.903, 0.0, 0.0, 8.1381, 9.1391, 0.0, 0.0, -8.158, 7.7566, 0.0, 0.0,
9.7341, 18.834, 0.0, 0.0, -1.514, -0.293, 18.686, 0.0873, 4.2010, -2.253, 0.0, 0.0};
float out[72] = {0};
MatrixPack(in, out, 9, 2, 36);
EXPECT_EQ(0, lite::CompareOutputData(out, correct, 72));
}
TEST_F(TestStrassenFp32, CommonMatmul1) {
float a_ptr[] = {7.756654, 19.250782, 17.923292, 0, 13.584222, 3.3293908, 9.734102, 0,
18.83455, -1.51425, -0.29382, 0, 18.686155, 0.0873076, 4.2010098, 0,
-2.2539594, 4.1795673, 13.14235, 0, -3.59393, 16.50578, 19.899279, 0,
8.556229, 19.969376, -6.2355065, 0, -2.380469, -9.027744, 9.5542, 0};
float b_ptr[] = {0.2674241, 0.089372, -0.081915, 2.0580146, -0.295045, 1.377944, 0.703658, 1.055378,
1.204049, -0.256505, -0.309640, 0.560465, 0, 0, 0, 0,
0.646906, 0, 0, 0, -0.168206, 0, 0, 0,
-0.95630, 0, 0, 0, 0, 0, 0, 0};
float correct[] = {17.97499, 22.622334, 7.360805, 46.325558, 14.37076, 3.304931, -1.784072, 36.925926,
5.129812, -0.3278886, -2.517368, 36.99899, 10.029593, 0.7127603, -2.77004, 40.90305,
13.988123, 2.186689, -0.943787, 7.138184, 18.128653, 17.31859, 5.7472067, 21.176342,
-11.11159, 29.880829, 15.281498, 35.1893, 13.530734, -15.10318, -9.11581, -9.071925,
-15.36046, 0, 0, 0, -1.081104, 0, 0, 0,
12.719885, 0, 0, 0, 8.056052, 0, 0, 0,
-14.72927, 0, 0, 0, -24.1311, 0, 0, 0,
8.139168, 0, 0, 0, -9.158176, 0, 0, 0};
StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter();
matmul_param->row_ = 8;
matmul_param->deep_ = 1;
matmul_param->col_ = 2;
matmul_param->a_stride_ = 32;
matmul_param->b_stride_ = 16;
matmul_param->c_stride_ = 32;
float c_ptr[64] = {0};
float tmp_ptr[32];
CommonMatMul(a_ptr, b_ptr, c_ptr, matmul_param, tmp_ptr);
EXPECT_EQ(0, lite::CompareOutputData(c_ptr, correct, 64));
delete matmul_param;
}
TEST_F(TestStrassenFp32, CommonMatmul2) {
StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter();
float a[] = {4.864725, 6.830073, 0.76780415, 8.922394, 5.096872, 2.4946148, 4.2148714, 1.7762588, 0.89195687,
9.703938, 2.0654619, 9.048538, 2.358036, 5.643526, 2.5152204, 3.512572, 3.7913973, 3.7136157,
8.820186, 1.5324963, 3.135459, 7.5792265, 7.1820426, 0.267987, 8.737802, 4.064117, 2.7232447,
0.27355433, 0, 0, 0, 0, 0, 0, 0, 0,
6.320409, 9.479354, 0, 0, 1.6220464, 0.57753897, 0, 0, 9.786372,
6.0404425, 0, 0, 2.1067812, 4.8034563, 0, 0, 2.1140356, 8.204062,
0, 0, 3.29985, 1.2034118, 0, 0, 7.6059656, 4.162436, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0};
float b[] = {
4.4558744, 0.6383263, 0.05037839, 9.730914, 8.1542015, 4.3625517, 8.654026, 3.805875, 9.845131, 4.08051,
9.667656, 7.73955, 9.283867, 8.465257, 2.292051, 9.853942, 0.13320169, 3.8789113, 9.460265, 4.2616735,
0.23831692, 4.420147, 0.5355651, 7.829217, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1.9866786, 0, 0, 0, 6.0188327, 0,
0, 0, 6.6249146, 0, 0, 0, 3.5639563, 0, 0, 0,
0.14810833, 0, 0, 0, 7.4168983, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
float c[] = {170.86482, 177.98166, 152.0957, 268.3473, 101.39282, 55.216248, 82.31873, 120.65008, 190.18558,
192.58974, 220.54767, 239.75931, 115.32386, 95.52758, 103.82857, 145.08948, 150.4757, 112.04814,
145.50496, 207.63342, 149.6962, 84.76027, 167.65851, 141.06763, 103.42963, 84.63687, 136.74927,
189.26935, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 158.90288, 0, 0, 0, 63.917973,
0, 0, 0, 152.3613, 0, 0, 0, 103.77265, 0,
0, 0, 154.94044, 0, 0, 0, 109.79707, 0, 0,
0, 92.83551, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
matmul_param->row_ = 7;
matmul_param->deep_ = 2;
matmul_param->col_ = 2;
matmul_param->a_stride_ = 36;
matmul_param->b_stride_ = 64;
matmul_param->c_stride_ = 40;
float out[80] = {0};
float tmp_ptr[1000];
CommonMatMul(a, b, out, matmul_param, tmp_ptr);
EXPECT_EQ(0, lite::CompareOutputData(out, c, 80));
delete (matmul_param);
}
TEST_F(TestStrassenFp32, RecMatmul1) {
StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter();
matmul_param->row_ = 4;
matmul_param->deep_ = 2;
matmul_param->col_ = 2;
matmul_param->a_stride_ = 16;
matmul_param->b_stride_ = 32;
matmul_param->c_stride_ = 16;
float a[] = {9.02165, 8.657163, 0.56371903, 0.7272156, 1.6258951, 9.919627, 7.47593, 3.5311592,
8.958062, 0.55338514, 9.611276, 7.429841, 8.23804, 3.7503464, 1.2829816, 6.4470887,
4.303486, 6.282502, 0, 0, 9.4194765, 7.8199654, 0, 0,
6.738705, 7.5398073, 0, 0, 0.47684374, 0.87746763, 0, 0};
float b[] = {1.8100919, 6.016964, 5.733568, 5.768448, 2.2823029, 2.173359, 0.56861514, 7.134393,
0.26377398, 3.9010656, 4.868408, 0.33401546, 1.7973539, 8.21896, 5.62239, 8.54786,
0.97356945, 1.0714527, 6.447588, 6.161091, 3.332229, 2.8775468, 6.558747, 2.6986659,
0, 0, 0, 0, 0, 0, 0, 0,
1.9830805, 0, 0, 0, 8.44718, 0, 0, 0,
9.360418, 0, 0, 0, 6.220693, 0, 0, 0,
1.8369701, 0, 0, 0, 4.3965054, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
float c[] = {62.668518, 103.9633, 132.43439, 163.67749, 69.12974, 122.12326, 183.23413, 191.96806,
65.052124, 182.57918, 233.14148, 184.20694, 38.785316, 118.74806, 100.689575, 135.12036,
136.34613, 0, 0, 0, 230.64507, 0, 0, 0,
204.15103, 0, 0, 0, 104.86488, 0, 0, 0};
float out[32] = {0};
float tmp_ptr[1000];
RecursionMatmul(a, b, out, matmul_param, 1, 0, tmp_ptr);
EXPECT_EQ(0, lite::CompareOutputData(out, c, 32));
delete (matmul_param);
}
TEST_F(TestStrassenFp32, RecMatmul2) {
StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter();
matmul_param->row_ = 4;
matmul_param->deep_ = 2;
matmul_param->col_ = 2;
matmul_param->a_stride_ = 32;
matmul_param->b_stride_ = 64;
matmul_param->c_stride_ = 32;
float a[] = {9.02165, 8.657163, 0.56371903, 0.7272156, 1.6258951, 9.919627, 7.47593, 3.5311592,
8.958062, 0.55338514, 9.611276, 7.429841, 8.23804, 3.7503464, 1.2829816, 6.4470887,
1, 2, 3, 4, 1, 2, 3, 4,
3, 2, 3, 4, 4, 2, 3, 4,
4.303486, 6.282502, 0, 0, 9.4194765, 7.8199654, 0, 0,
6.738705, 7.5398073, 0, 0, 0.47684374, 0.87746763, 0, 0,
1, 2, 3, 4, 1, 2, 3, 4,
3, 2, 3, 4, 4, 2, 3, 4};
float b[] = {
1.8100919, 6.016964, 5.733568, 5.768448, 2.2823029, 2.173359, 0.56861514, 7.134393, 0.26377398, 3.9010656,
4.868408, 0.33401546, 1.7973539, 8.21896, 5.62239, 8.54786, 0.97356945, 1.0714527, 6.447588, 6.161091,
3.332229, 2.8775468, 6.558747, 2.6986659, 0, 0, 0, 0, 0, 0,
0, 0, 11, 2, 3, 4, 22, 2, 3, 4,
33, 3, 3, 4, 44, 2, 3, 4, 11, 2,
3, 4, 22, 2, 3, 4, 33, 3, 3, 4,
44, 2, 3, 4, 1.9830805, 0, 0, 0, 8.44718, 0,
0, 0, 9.360418, 0, 0, 0, 6.220693, 0, 0, 0,
1.8369701, 0, 0, 0, 4.3965054, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 11, 2, 3, 4,
22, 2, 3, 4, 33, 3, 3, 4, 44, 2,
3, 4, 11, 2, 3, 4, 22, 2, 3, 4,
33, 3, 3, 4, 44, 2, 3, 4};
float c[] = {62.668518, 103.9633, 132.43439, 163.67749, 69.12974, 122.12326, 183.23413, 191.96806,
65.052124, 182.57918, 233.14148, 184.20694, 38.785316, 118.74806, 100.689575, 135.12036,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
136.34613, 0, 0, 0, 230.64507, 0, 0, 0,
204.15103, 0, 0, 0, 104.86488, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
float out[64] = {0};
float tmp_ptr[1000];
RecursionMatmul(a, b, out, matmul_param, 1, 0, tmp_ptr);
EXPECT_EQ(0, lite::CompareOutputData(out, c, 64));
delete (matmul_param);
}
} // namespace mindspore

View File

@ -0,0 +1,266 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/pack.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/int8/deconv.h"
#include "mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h"
using mindspore::lite::DeviceType;
namespace mindspore {
using mindspore::lite::tensor::QuantArg;
using mindspore::lite::tensor::Tensor;
using mindspore::schema::Format_NHWC;
using mindspore::schema::NodeType_Parameter;
class TestDeconvInt8 : public mindspore::Common {
public:
TestDeconvInt8() {}
};
void FloatToInt8(float *fptr, int8_t *iptr, size_t size, int32_t zp, double scale) {
for (int i = 0; i < size; i++) {
int32_t value = round(fptr[i] / scale + zp);
value = MSMIN(value, INT8_MAX);
value = MSMAX(value, INT8_MIN);
iptr[i] = (int8_t)value;
}
}
TEST_F(TestDeconvInt8, PackWeight1) {
int8_t in[] = {-8, 11, 99, -80, 8, -12, 37, -45, 31, -69, -66, 26, 112, 124, -109, 85, -24, 28, -46, 100,
72, -36, -82, 64, -110, 37, -72, 65, -124, 91, -43, 99, 3, 100, 19, 51, -14, -81, 67, 90,
4, -106, 105, 28, -61, -79, 55, -54, 47, -38, 114, 125, -65, 100, 6, -72, -33, 60, 109, -68};
int8_t co[] = {-8, 11, 99, -80, 8, -12, 0, 0, 112, 124, -109, 85, -24, 28, 0, 0, -110, 37, -72, 65,
-124, 91, 0, 0, -14, -81, 67, 90, 4, -106, 0, 0, 47, -38, 114, 125, -65, 100, 0, 0,
37, -45, 31, -69, -66, 26, 0, 0, -46, 100, 72, -36, -82, 64, 0, 0, -43, 99, 3, 100,
19, 51, 0, 0, 105, 28, -61, -79, 55, -54, 0, 0, 6, -72, -33, 60, 109, -68, 0, 0};
int8_t dst[80] = {0};
/*5*1*2*6 nhwc*/
PackNHWCToC8HWN8Int8(in, dst, 5, 2, 6);
CompareOutputData(dst, co, 80, 1);
}
TEST_F(TestDeconvInt8, PackWeight2) {
int8_t in[] = {
40, 24, 94, 122, 67, 34, -89, 31, -43, 121, 48, -54, 44, -91, 35, 89, -37, 114, -8, 103,
-22, 32, 26, 112, -92, -23, 43, 9, 81, 118, -73, -54, 65, -99, 51, -90, 121, -62, 119, -93,
21, -92, -1, -82, -71, -54, 63, -93, 92, -93, 99, 122, -104, -16, -8, -32, 90, -126, 51, 91,
4, 70, -7, 116, 99, 81, -79, 124, -14, 28, 97, 9, -97, 99, 88, -15, 54, 26, 77, -25,
113, 119, 119, -75, -17, 7, 7, 1, 69, 66, 40, -13, 80, -115, -98, -8, -17, 31, 88, 65,
-1, -15, -98, 77, 56, 119, -20, -32, -54, -58, -16, 52, 121, 126, -33, 43, 92, -34, -17, -52,
104, -52, -91, 76, 79, 105, 102, -65, 43, 32, 13, 15, -38, 95, -18, -82, -7, 118, -79, -85,
120, -15, 2, 32, -94, 111, 115, 102, -18, 121, -106, 54, 63, 111, -16, 92, 82, -23, 111, 53,
1, -48, 45, 19, -4, -15, -72, 41, 80, -51, 116, 31, 94, 101, -10, 18, 0, -49, 108, 28,
-36, 47, -14, -2, -10, 31, -92, -84, 74, -114, -107, 66, 99, -121, -107, 31, -38, 56, -30, 109,
-7, 28, -22, -17, -3, -2, 27, -3, 108, -84, -23, -71, -54, 20, -45, 109, -42, 78, -79, 98,
-10, 57, 52, 1, 25, 73, 21, -78, 46, 121, 66, 92, 24, 55, 4, -110, -37, 112, -18, 10,
-42, 16, -9, 31, 39, -70, 108, -3, -90, -60, -121, 11, 50, -88, -104, -29, -89, 94, 64, -91,
-101, -7, 23, -57, 93, 16, 17, 35, -48, -25, 13, -121, 73, -68, -54, -122, -20, 12, 64, 20,
-11, -6, -71, -52, -97, 109, 116, -107, 117, -124, 56, 80, -108, 30, 123, 56, -80, 39, -18, -97,
-103, 122, 114, -10, -31, 97, -92, 105, -61, -25, 10, -119, -106, 41, 77, -117, 55, -83, -29, 14,
27, -106, -86, 41, 43, 23, 11, -76, -34, 121, 94, 18, 69, 73, 100, 54, 43, 32, 13, 15,
-38, 95, -18, -82, -7, 118, -79, -85, 120, -15, 2, 32, -94, 111, 115, 102, -18, 121, -106, 54,
63, 111, -16, 92, 82, -23, 111, 53, 1, -48, 45, 19, -4, -15, -72, 41, 80, -51, 116, 31,
94, 101, -10, 18, 0, -49, 108, 28, -36, 47, -14, -2, -10, 31, -92, -84, 74, -114, -107, 66,
99, -121, -107, 31, -38, 56, -30, 109, -7, 28, -22, -17, -3, -2, 27, -3, 108, -84, -23, -71,
-54, 20, -45, 109, -42, 78, -79, 98, -10, 57, 52, 1, 25, 73, 21, -78, 46, 121, 66, 92};
int8_t co[] = {
40, 24, 94, 122, 67, 34, -89, 31, -22, 32, 26, 112, -92, -23, 43, 9, 21, -92, -1, -82,
-71, -54, 63, -93, 4, 70, -7, 116, 99, 81, -79, 124, 113, 119, 119, -75, -17, 7, 7, 1,
-1, -15, -98, 77, 56, 119, -20, -32, 104, -52, -91, 76, 79, 105, 102, -65, 120, -15, 2, 32,
-94, 111, 115, 102, 1, -48, 45, 19, -4, -15, -72, 41, -36, 47, -14, -2, -10, 31, -92, -84,
-7, 28, -22, -17, -3, -2, 27, -3, -10, 57, 52, 1, 25, 73, 21, -78, -42, 16, -9, 31,
39, -70, 108, -3, -101, -7, 23, -57, 93, 16, 17, 35, -11, -6, -71, -52, -97, 109, 116, -107,
-103, 122, 114, -10, -31, 97, -92, 105, 27, -106, -86, 41, 43, 23, 11, -76, -38, 95, -18, -82,
-7, 118, -79, -85, 63, 111, -16, 92, 82, -23, 111, 53, 94, 101, -10, 18, 0, -49, 108, 28,
99, -121, -107, 31, -38, 56, -30, 109, -54, 20, -45, 109, -42, 78, -79, 98, -43, 121, 48, -54,
44, -91, 35, 89, 81, 118, -73, -54, 65, -99, 51, -90, 92, -93, 99, 122, -104, -16, -8, -32,
-14, 28, 97, 9, -97, 99, 88, -15, 69, 66, 40, -13, 80, -115, -98, -8, -54, -58, -16, 52,
121, 126, -33, 43, 43, 32, 13, 15, -38, 95, -18, -82, -18, 121, -106, 54, 63, 111, -16, 92,
80, -51, 116, 31, 94, 101, -10, 18, 74, -114, -107, 66, 99, -121, -107, 31, 108, -84, -23, -71,
-54, 20, -45, 109, 46, 121, 66, 92, 24, 55, 4, -110, -90, -60, -121, 11, 50, -88, -104, -29,
-48, -25, 13, -121, 73, -68, -54, -122, 117, -124, 56, 80, -108, 30, 123, 56, -61, -25, 10, -119,
-106, 41, 77, -117, -34, 121, 94, 18, 69, 73, 100, 54, 120, -15, 2, 32, -94, 111, 115, 102,
1, -48, 45, 19, -4, -15, -72, 41, -36, 47, -14, -2, -10, 31, -92, -84, -7, 28, -22, -17,
-3, -2, 27, -3, -10, 57, 52, 1, 25, 73, 21, -78, -37, 114, -8, 103, 0, 0, 0, 0,
121, -62, 119, -93, 0, 0, 0, 0, 90, -126, 51, 91, 0, 0, 0, 0, 54, 26, 77, -25,
0, 0, 0, 0, -17, 31, 88, 65, 0, 0, 0, 0, 92, -34, -17, -52, 0, 0, 0, 0,
-7, 118, -79, -85, 0, 0, 0, 0, 82, -23, 111, 53, 0, 0, 0, 0, 0, -49, 108, 28,
0, 0, 0, 0, -38, 56, -30, 109, 0, 0, 0, 0, -42, 78, -79, 98, 0, 0, 0, 0,
-37, 112, -18, 10, 0, 0, 0, 0, -89, 94, 64, -91, 0, 0, 0, 0, -20, 12, 64, 20,
0, 0, 0, 0, -80, 39, -18, -97, 0, 0, 0, 0, 55, -83, -29, 14, 0, 0, 0, 0,
43, 32, 13, 15, 0, 0, 0, 0, -18, 121, -106, 54, 0, 0, 0, 0, 80, -51, 116, 31,
0, 0, 0, 0, 74, -114, -107, 66, 0, 0, 0, 0, 108, -84, -23, -71, 0, 0, 0, 0,
46, 121, 66, 92, 0, 0, 0, 0};
int8_t dst[528] = {0};
PackNHWCToC8HWN8Int8(in, dst, 22, 1, 20);
CompareOutputData(dst, co, 528, 1);
}
TEST_F(TestDeconvInt8, MatMulTest1) {
int8_t a_row_major_10_12[] = {
-6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, -41, 117, 62, -76, -77, -111, 88, 105,
68, 105, -74, 13, 51, 94, 31, -52, -92, -4, -35, -71, 101, -93, 46, -65, 57, -41, -51, 77,
1, 9, 73, -19, -36, 57, 81, -24, 40, 103, 112, 109, -41, -68, 57, 61, 55, -20, 3, 2,
17, -16, -31, 58, -4, 67, -4, -95, -5, -72, 81, 15, -7, -16, -47, 112, 114, -26, -98, 53,
15, -49, 26, 19, 19, 8, -57, -35, -79, 118, 29, 21, 37, -48, 83, 7, 124, 113, -5, 15,
-8, 107, -65, -88, 50, -47, -80, -84, 3, -45, 92, 42, -20, -101, 106, -10, 89, 67, 55, 10};
int32_t zp_a = 15;
int8_t a_col8_major[16 * 12] = {0};
int8_t b_col_major_12_18[] = {
92, 27, 22, 52, -112, -20, -57, -2, 89, 32, 93, -66, -25, -54, 94, -97, -119, -98, 101, -99,
77, -83, 76, 95, 59, 97, 8, 40, -109, -20, 67, -107, 37, -6, -54, -20, -30, 36, -106, -103,
-3, -86, -82, 59, 4, -75, -50, -106, 55, 104, -117, -71, -20, -85, -77, 16, -25, -58, 4, 80,
-75, 94, 32, -68, 2, 40, 56, -103, 11, -98, -70, -69, 0, 57, -6, 82, 66, -112, -61, 33,
-77, -53, 95, -38, 87, -46, -3, 81, -47, 43, 21, 26, -45, -57, 50, -24, -82, -114, 61, 46,
-53, 78, -24, 31, -7, 37, 29, 38, 45, 106, 52, -42, 31, -6, -61, -87, 2, 79, -5, -42,
43, -106, -104, 7, 91, -63, 58, 97, -15, 74, -96, 15, -23, -3, -47, -97, 100, -54, 26, -46,
35, 26, 100, -80, 34, -25, 96, -67, -80, -27, 66, 41, 41, -43, -43, -38, -4, -64, 31, 7,
-8, 6, -2, 39, -119, 53, 75, -91, -44, 77, -62, 22, -44, 78, -67, -48, -115, -4, 43, 81,
40, -20, -5, -89, 60, -62, -4, -48, 66, -64, -69, 62, 17, -89, 1, 87, 81, 32, -29, 51,
40, 27, 66, 67, 11, -69, 85, -79, -106, 55, 22, -23, 62, 69, -74, 49};
int32_t zp_b = -20;
int8_t b_row8_major[12 * 24] = {0};
int32_t co_row_major_10_18[] = {
32005, 3597, 16595, -3458, 6627, -6663, 818, -3910, 10228, 15079, -19205, -10203, -3178, -10046,
10374, -6199, 5330, 12163, 1819, 20533, 17382, 18283, 9778, 9185, -12623, -26234, -11987, 7904,
8144, -1603, 27611, -10190, -20053, 4999, -28389, 21852, 24680, 25858, 23506, 17944, 11768, 24378,
-6102, -4675, -23460, 10434, -47579, 1986, 12018, -19418, -7248, 4938, -32613, -941, 8171, -4788,
3325, -11310, -8351, -14786, 6909, 16401, 2017, -6456, 11242, 7393, -9119, 17312, 2646, -14402,
7201, -9949, 23986, 17607, 27461, -1547, 2783, 7558, 19487, 11158, -2686, 6328, -8225, -11668,
21858, -2079, -8671, -639, -1544, 1235, 1156, 6582, 2829, -10311, -2692, 5154, 1527, 10870,
106, -8189, -24174, -1846, -15399, -3598, 14874, -5591, -619, -13667, -6053, -31103, -24499, 13008,
9143, -17982, 28437, 2176, -2114, -11631, 10779, -1032, -24690, -3112, 2125, 432, 20270, -33859,
8907, 10063, 1603, 3761, 4805, 4904, -15594, 10786, 4287, -13591, -18777, -1679, 2109, -2243,
12051, -8504, -6558, 4209, 13606, -25803, 27922, 12092, 7140, 27142, -12267, 2339, -26224, 23674,
-26579, -11398, -1823, -18976, 3641, 4415, -24878, -2045, 15937, 41465, 12601, -14513, -17619, -5728,
334, -424, 8147, -1369, 5984, 11000, 19016, 4456, -25920, 4506, 5930, 15458};
int32_t c_row8x8_major[16 * 24] = {0};
int32_t out_row_major[180] = {0};
RowMajor2Col8MajorInt8(a_row_major_10_12, a_col8_major, 10, 12);
RowMajor2Col8MajorInt8(b_col_major_12_18, b_row8_major, 18, 12);
MatMulInt8(a_col8_major, b_row8_major, c_row8x8_major, 16, 24, 12, zp_a, zp_b);
Row8x8Major2RowMajor(reinterpret_cast<float *>(c_row8x8_major), reinterpret_cast<float *>(out_row_major), 10, 18);
CompareOutputData(out_row_major, co_row_major_10_18, 180, 1);
}
TEST_F(TestDeconvInt8, PostAddTest1) {
int32_t in[] = {
-4956, -3923, 868, -8880, -4089, -5179, -4526, -4527, -10464, 99, -5826, -2995, -4519, -4519, -10509, -2505,
-11272, 434, -4522, -4523, -5287, -8936, -878, 373, -4528, -4529, -1960, -6589, 1688, 2287, -8059, 926,
-2506, -6972, -2834, -8281, -8118, -3110, -4526, -4527, -4528, -4529, -4519, -4519, -4519, -4519, -4519, -4519,
-4520, -4521, -4522, -4523, -4524, -4525, -4526, -4527, -4528, -4529, -4519, -4519, -4519, -4519, -4519, -4519,
1578, 2231, -4522, -4523, -4524, -4525, -4526, -4527, -8449, -990, -4519, -4519, -4519, -4519, -4519, -4519,
-4303, -10293, -4522, -4523, -4524, -4525, -4526, -4527, -4528, -4529, -4519, -4519, -4519, -4519, -4519, -4519,
-7025, 924, -4522, -4523, -4524, -4525, -4526, -4527, -4528, -4529, -4519, -4519, -4519, -4519, -4519, -4519,
-4520, -4521, -4522, -4523, -4524, -4525, -4526, -4527, -4528, -4529, -4519, -4519, -4519, -4519, -4519, -4519};
int8_t co[] = {-8, 11, 99, -80, 8, -12, 0, 0, 112, 124, -109, 85, -24, 28, 0, 0, -110,
37, -72, 65, -124, 91, 0, 0, -14, -81, 67, 90, 4, -106, 0, 0, 47, -38,
114, 125, -65, 100, 0, 0, 37, -45, 31, -69, -66, 26, 0, 0, -46, 100};
int32_t bias[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
int8_t out[50] = {0};
double multiplier = 0.0183649725490196;
int32_t quant_multiplier;
int32_t left_shift;
int32_t right_shift;
QuantizeRoundParameter(multiplier, &quant_multiplier, &left_shift, &right_shift);
int32_t zp = 83;
PostFuncInt8(in, bias, out, 10, 5, 8, quant_multiplier, left_shift, right_shift, zp, -128, 127);
CompareOutputData(out, co, 50, 1);
int8_t co_relu[] = {0, 11, 99, 0, 8, 0, 0, 0, 112, 124, 0, 85, 0, 28, 0, 0, 0, 37, 0, 65, 0, 91, 0, 0, 0,
0, 67, 90, 4, 0, 0, 0, 47, 0, 114, 125, 0, 100, 0, 0, 37, 0, 31, 0, 0, 26, 0, 0, 0, 100};
PostFuncInt8(in, bias, out, 10, 5, 8, quant_multiplier, left_shift, right_shift, zp, 0, 127);
CompareOutputData(out, co_relu, 50, 1);
int8_t co_relu6[] = {0, 6, 6, 0, 6, 0, 0, 0, 6, 6, 0, 6, 0, 6, 0, 0, 0, 6, 0, 6, 0, 6, 0, 0, 0,
0, 6, 6, 4, 0, 0, 0, 6, 0, 6, 6, 0, 6, 0, 0, 6, 0, 6, 0, 0, 6, 0, 0, 0, 6};
PostFuncInt8(in, bias, out, 10, 5, 8, quant_multiplier, left_shift, right_shift, zp, 0, 6);
CompareOutputData(out, co_relu6, 50, 1);
}
int DeConvInt8TestInit1(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
ConvParameter *conv_param, int8_t **correct) {
/* float data from deconv fp32 testcase : DeConvTestInit2 */
/* vq = (vi - zp) * s vi = vq / s + zp */
Tensor *in_t = new Tensor(kNumberTypeInt8, {1, 4, 2, 3}, Format_NHWC, NodeType_Parameter);
in_t->MallocData();
int8_t in[] = {6, 43, 38, 24, -8, 12, 41, -24, -20, 41, -19, -6, -26, -6, 23, -31, 34, 45, 8, 45, -39, -27, -48, 12};
memcpy(in_t->Data(), in, sizeof(int8_t) * in_t->ElementsNum());
QuantArg *in_quant_arg = new QuantArg();
in_quant_arg->zeroPoint = -19, in_quant_arg->scale = 0.31228156;
in_t->AddQuantParam(*in_quant_arg);
inputs_->push_back(in_t);
Tensor *weight_t = new Tensor(kNumberTypeInt8, {3, 3, 3, 2}, Format_NHWC, NodeType_Parameter);
weight_t->MallocData();
int8_t weight[] = {66, 89, 98, 74, 95, 86, 125, 95, 105, 83, 116, 94, 90, 80, 86, 59, 72, 92,
64, 76, 92, 80, 90, 87, 106, 55, 105, 60, 75, 53, 81, 81, 98, 81, 86, 59,
74, 82, 97, 105, 71, 67, 79, 87, 72, 79, 80, 76, 96, 80, 83, 71, 61, 79};
memcpy(weight_t->Data(), weight, sizeof(int8_t) * weight_t->ElementsNum());
QuantArg *w_quant_arg = new QuantArg();
w_quant_arg->zeroPoint = 83, w_quant_arg->scale = 0.023649725490196;
weight_t->AddQuantParam(*w_quant_arg);
inputs_->push_back(weight_t);
Tensor *out_t = new Tensor(kNumberTypeInt8, {1, 7, 3, 2}, Format_NHWC, NodeType_Parameter);
out_t->MallocData();
QuantArg *out_quant_arg = new QuantArg();
out_quant_arg->zeroPoint = 31, out_quant_arg->scale = 0.3439215686275;
out_t->AddQuantParam(*out_quant_arg);
outputs_->push_back(out_t);
*correct = reinterpret_cast<int8_t *>(malloc(out_t->ElementsNum() * sizeof(int8_t)));
int8_t co_nchw[] = {57, 76, 49, 71, 8, 61, 57, 127, 56, 46, -11, 61, 23, 31, 34, 50, 59, 49, 78, 17, 6,
-3, -5, 23, -11, 6, -5, 33, 64, 30, 21, 18, 25, 21, -15, 0, 4, 31, 36, 2, 17, 43};
PackNCHWToNHWCInt8(co_nchw, *correct, out_t->Batch(), out_t->Width() * out_t->Height(), out_t->Channel());
conv_param->kernel_h_ = conv_param->kernel_w_ = 3;
conv_param->pad_h_ = conv_param->pad_w_ = 1;
conv_param->stride_h_ = conv_param->stride_w_ = 2;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
return out_t->ElementsNum();
}
TEST_F(TestDeconvInt8, DeConvInt8Test1) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto deconv_param = new ConvParameter();
lite::Context *ctx = new lite::Context;
ctx->threadNum = 2;
int8_t *correct;
int total_size = DeConvInt8TestInit1(&inputs_, &outputs_, deconv_param, &correct);
mindspore::kernel::DeConvInt8CPUKernel *deconv =
new mindspore::kernel::DeConvInt8CPUKernel(reinterpret_cast<OpParameter *>(deconv_param), inputs_, outputs_, ctx);
deconv->Init();
deconv->Run();
CompareOutputData(reinterpret_cast<int8_t *>(outputs_[0]->Data()), correct, total_size, 3);
delete deconv_param;
// delete deconv;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
}
} // namespace mindspore

View File

@ -27,7 +27,7 @@ namespace mindspore {
using lite::tensor::Tensor;
class TestFcInt8 : public mindspore::Common {
public:
TestFcInt8(){}
TestFcInt8() {}
};
void Quantize(float *input_data, int length, float scale, int zero_point, int8_t *output_data) {
@ -110,8 +110,7 @@ int FcInt8TestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lit
matmal_param->b_transpose_ = true;
matmal_param->a_transpose_ = false;
matmal_param->has_bias_ = true;
matmal_param->minf_ = -FLT_MAX;
matmal_param->maxf_ = FLT_MAX;
matmal_param->act_type_ = ActType_No;
return out_t->ElementsNum();
}

View File

@ -0,0 +1,201 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "include/context.h"
#include "src/ir/tensor.h"
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "src/runtime/kernel/arm/opclib/pad_parameter.h"
#include "src/runtime/kernel/arm/int8/pad_int8.h"
namespace mindspore {
using mindspore::lite::tensor::QuantArg;
using mindspore::lite::tensor::Tensor;
class TestPadInt8 : public mindspore::Common {
public:
TestPadInt8() {}
};
int PadInt8TestInit1(std::vector<Tensor *> *inputs_, std::vector<Tensor *> *outputs_, PadParameter *pad_param,
int8_t **correct) {
Tensor *in_t = new Tensor(kNumberTypeInt8, {3}, schema::Format_NHWC, schema::NodeType_Parameter);
in_t->MallocData();
int8_t in[] = {1, 1, 1};
memcpy(in_t->Data(), in, sizeof(int8_t) * in_t->ElementsNum());
QuantArg *in_quant_arg = new QuantArg();
in_quant_arg->zeroPoint = 10, in_quant_arg->scale = 0.31228156;
in_t->AddQuantParam(*in_quant_arg);
inputs_->push_back(in_t);
Tensor *out_t = new Tensor(kNumberTypeInt8, {7}, schema::Format_NHWC, schema::NodeType_Parameter);
out_t->MallocData();
QuantArg *out_quant_arg = new QuantArg();
out_quant_arg->zeroPoint = 10, out_quant_arg->scale = 0.31228156;
out_t->AddQuantParam(*out_quant_arg);
outputs_->push_back(out_t);
*correct = reinterpret_cast<int8_t *>(malloc(out_t->ElementsNum() * sizeof(int8_t)));
int8_t co[] = {10, 10, 1, 1, 1, 10, 10};
memcpy(*correct, co, out_t->ElementsNum() * sizeof(int8_t));
int padding[] = {0, 0, 0, 0, 0, 0, 2, 2};
memcpy(pad_param->paddings_, padding, MAX_PAD_SIZE * sizeof(int));
pad_param->constant_value_ = 0;
return out_t->ElementsNum();
}
TEST_F(TestPadInt8, PadInt8Test1) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto pad_param = new PadParameter();
lite::Context *ctx = new lite::Context;
int8_t *correct;
int total_size = PadInt8TestInit1(&inputs_, &outputs_, pad_param, &correct);
kernel::PadInt8CPUKernel *pad =
new kernel::PadInt8CPUKernel(reinterpret_cast<OpParameter *>(pad_param), inputs_, outputs_, ctx);
pad->Init();
pad->Run();
CompareOutputData(reinterpret_cast<int8_t *>(outputs_[0]->Data()), correct, total_size, 0);
delete pad_param;
delete pad;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
}
int PadInt8TestInit2(std::vector<Tensor *> *inputs_, std::vector<Tensor *> *outputs_, PadParameter *pad_param,
int8_t **correct) {
Tensor *in_t = new Tensor(kNumberTypeInt8, {6, 2}, schema::Format_NHWC, schema::NodeType_Parameter);
in_t->MallocData();
int8_t in[] = {18, 71, 99, -6, 5, -119, 86, 13, 15, -85, -41, -77};
memcpy(in_t->Data(), in, sizeof(int8_t) * in_t->ElementsNum());
QuantArg *in_quant_arg = new QuantArg();
in_quant_arg->zeroPoint = 10, in_quant_arg->scale = 0.31228156;
in_t->AddQuantParam(*in_quant_arg);
inputs_->push_back(in_t);
Tensor *out_t = new Tensor(kNumberTypeInt8, {10, 5}, schema::Format_NHWC, schema::NodeType_Parameter);
out_t->MallocData();
QuantArg *out_quant_arg = new QuantArg();
out_quant_arg->zeroPoint = 10, out_quant_arg->scale = 0.31228156;
out_t->AddQuantParam(*out_quant_arg);
outputs_->push_back(out_t);
*correct = reinterpret_cast<int8_t *>(malloc(out_t->ElementsNum() * sizeof(int8_t)));
int8_t co[] = {10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 18,
71, 10, 10, 10, 99, -6, 10, 10, 10, 5, -119, 10, 10, 10, 86, 13, 10,
10, 10, 15, -85, 10, 10, 10, -41, -77, 10, 10, 10, 10, 10, 10, 10};
memcpy(*correct, co, out_t->ElementsNum() * sizeof(int8_t));
int padding[] = {0, 0, 0, 0, 3, 1, 1, 2};
memcpy(pad_param->paddings_, padding, MAX_PAD_SIZE * sizeof(int));
pad_param->constant_value_ = 0;
return out_t->ElementsNum();
}
TEST_F(TestPadInt8, PadInt8Test2) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto pad_param = new PadParameter();
lite::Context *ctx = new lite::Context;
int8_t *correct;
int total_size = PadInt8TestInit2(&inputs_, &outputs_, pad_param, &correct);
kernel::PadInt8CPUKernel *pad =
new kernel::PadInt8CPUKernel(reinterpret_cast<OpParameter *>(pad_param), inputs_, outputs_, ctx);
pad->Init();
pad->Run();
CompareOutputData(reinterpret_cast<int8_t *>(outputs_[0]->Data()), correct, total_size, 0);
delete pad_param;
delete pad;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
}
int PadInt8TestInit4(std::vector<Tensor *> *inputs_, std::vector<Tensor *> *outputs_, PadParameter *pad_param,
int8_t **correct) {
Tensor *in_t = new Tensor(kNumberTypeInt8, {2, 3, 2, 1}, schema::Format_NHWC, schema::NodeType_Parameter);
in_t->MallocData();
int8_t in[] = {73, 24, 7, -31, -109, -2, 69, -64, 51, -45, 38, 53};
memcpy(in_t->Data(), in, sizeof(int8_t) * in_t->ElementsNum());
QuantArg *in_quant_arg = new QuantArg();
in_quant_arg->zeroPoint = 10, in_quant_arg->scale = 0.31228156;
in_t->AddQuantParam(*in_quant_arg);
inputs_->push_back(in_t);
Tensor *out_t = new Tensor(kNumberTypeInt8, {6, 6, 4, 3}, schema::Format_NHWC, schema::NodeType_Parameter);
out_t->MallocData();
QuantArg *out_quant_arg = new QuantArg();
out_quant_arg->zeroPoint = 10, out_quant_arg->scale = 0.31228156;
out_t->AddQuantParam(*out_quant_arg);
outputs_->push_back(out_t);
*correct = reinterpret_cast<int8_t *>(malloc(out_t->ElementsNum() * sizeof(int8_t)));
int8_t co[] = {
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 73, 10, 10, 24, 10, 10, 10, 10,
10, 10, 10, 10, 7, 10, 10, -31, 10, 10, 10, 10, 10, 10, 10, 10, -109, 10, 10, -2, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 69, 10, 10, -64, 10, 10, 10, 10, 10, 10, 10, 10, 51, 10, 10, -45, 10,
10, 10, 10, 10, 10, 10, 10, 38, 10, 10, 53, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10};
memcpy(*correct, co, out_t->ElementsNum() * sizeof(int8_t));
int padding[] = {3, 1, 1, 2, 2, 0, 1, 1};
memcpy(pad_param->paddings_, padding, MAX_PAD_SIZE * sizeof(int));
pad_param->constant_value_ = 0;
return out_t->ElementsNum();
}
TEST_F(TestPadInt8, PadInt8TestInit4) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto pad_param = new PadParameter();
lite::Context *ctx = new lite::Context;
int8_t *correct;
int total_size = PadInt8TestInit2(&inputs_, &outputs_, pad_param, &correct);
kernel::PadInt8CPUKernel *pad =
new kernel::PadInt8CPUKernel(reinterpret_cast<OpParameter *>(pad_param), inputs_, outputs_, ctx);
pad->Init();
pad->Run();
CompareOutputData(reinterpret_cast<int8_t *>(outputs_[0]->Data()), correct, total_size, 0);
delete pad_param;
delete pad;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
}
} // namespace mindspore