!8759 [MSLITE] conv 1x1 int8 parallel support by hw and oc

From: @ling_qiao_min
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-20 09:20:53 +08:00 committed by Gitee
commit 25f75bd8ab
4 changed files with 197 additions and 59 deletions

View File

@ -804,8 +804,8 @@ void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int *filter_zp) {
int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
matmul_func(packed_input, packed_weight, dst, row, col, deep4, col, input_sum, bias, left_shift, right_shift,
multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias,
left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc,
filter_zp);
return;

View File

@ -292,7 +292,7 @@ void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM;
int c16div = c / C16NUM, c16mod = c % C16NUM;
size_t ci = r * col + c;
size_t ci = r * stride + c;
int32_t value = 0;
for (int d = 0; d < deep_4; d++) {
int d4div = d / C4NUM, d4mod = d % C4NUM;

View File

@ -17,9 +17,7 @@
#include "src/runtime/kernel/arm/int8/convolution_1x1_int8.h"
#include "src/runtime/runtime_api.h"
#include "src/common/file_utils.h"
#ifdef ENABLE_ARM64
#include "src/runtime/kernel/arm/int8/opt_op_handler.h"
#endif
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
@ -63,6 +61,60 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() {
return;
}
int Convolution1x1Int8HwRun(void *cdata, int task_id) {
auto conv = reinterpret_cast<Convolution1x1Int8CPUKernel *>(cdata);
auto error_code = conv->HwRun(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv1x1 Int8 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int Convolution1x1Int8OcRun(void *cdata, int task_id) {
auto conv = reinterpret_cast<Convolution1x1Int8CPUKernel *>(cdata);
auto error_code = conv->OcRun(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv1x1 Int8 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int Convolution1x1Int8OcOptPre(void *cdata, int task_id) {
auto conv = reinterpret_cast<Convolution1x1Int8CPUKernel *>(cdata);
auto error_code = conv->OcOptPre(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv1x1 Int8 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int Convolution1x1Int8CPUKernel::OcRun(int task_id) {
#ifdef ENABLE_ARM32
return RunArm32Oc(task_id);
#else
if (support_optimize_) {
return RunArm64OptOc(task_id);
} else {
return RunArm64Oc(task_id);
}
#endif
}
int Convolution1x1Int8CPUKernel::HwRun(int task_id) {
#ifdef ENABLE_ARM32
return RunArm32Hw(task_id);
#else
if (support_optimize_) {
return RunArm64OptHw(task_id);
} else {
return RunArm64Hw(task_id);
}
#endif
}
int Convolution1x1Int8CPUKernel::InitRunBuf() {
input_sum_ = reinterpret_cast<int32_t *>(ctx_->allocator->Malloc(input_sum_size_ * sizeof(int32_t)));
if (input_sum_ == nullptr) {
@ -308,9 +360,6 @@ int Convolution1x1Int8CPUKernel::InitParam() {
}
}
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, row_pack_count));
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, row_pack_count), thread_count_);
if (pre_trans_input_) {
input_ptr_ = reinterpret_cast<int8_t *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(int8_t)));
if (input_ptr_ == nullptr) {
@ -319,6 +368,15 @@ int Convolution1x1Int8CPUKernel::InitParam() {
}
memset(input_ptr_, 0, matmul_param_->row_ * matmul_param_->deep_ * sizeof(int8_t));
}
int hw_thread_count = UP_DIV(matmul_param_->row_, row_pack_count);
int oc_thread_count = UP_DIV(matmul_param_->col_, col_pack_count);
thread_count_hw_ = MSMIN(op_parameter_->thread_num_, hw_thread_count);
thread_stride_hw_ = UP_DIV(hw_thread_count, thread_count_hw_);
thread_count_oc_ = MSMIN(op_parameter_->thread_num_, oc_thread_count);
thread_stride_oc_ = UP_DIV(oc_thread_count, thread_count_oc_);
parallel_by_oc_ = hw_thread_count < op_parameter_->thread_num_;
return RET_OK;
}
@ -346,19 +404,19 @@ void Convolution1x1Int8CPUKernel::Pre1x1Trans(int8_t *src_input, int8_t *src_out
return;
}
int Convolution1x1Int8CPUKernel::RunArm64(int task_id) {
int cur_stride = thread_stride_ * C4NUM;
int res_stride = matmul_param_->row_ - task_id * thread_stride_ * C4NUM;
int Convolution1x1Int8CPUKernel::RunArm64Hw(int task_id) {
int cur_stride = thread_stride_hw_ * C4NUM;
int res_stride = matmul_param_->row_ - task_id * thread_stride_hw_ * C4NUM;
int cur_hw = MSMIN(cur_stride, res_stride);
if (cur_hw <= 0) {
return RET_OK;
}
int8_t *hw_in = input_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->input_channel_;
int8_t *hw_out = output_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->output_channel_;
int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_;
int32_t *hw_input_sum = filter_peroc_ ? input_sum_ + task_id * thread_stride_ * C4NUM * matmul_param_->col_4_
: input_sum_ + task_id * thread_stride_ * C4NUM;
int8_t *hw_in = input_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->input_channel_;
int8_t *hw_out = output_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->output_channel_;
int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->deep_16_;
int32_t *hw_input_sum = filter_peroc_ ? input_sum_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->col_4_
: input_sum_ + task_id * thread_stride_hw_ * C4NUM;
RowMajor2Row16x4MajorInt8(hw_in, hw_packed_in, cur_hw, matmul_param_->deep_);
@ -375,19 +433,19 @@ int Convolution1x1Int8CPUKernel::RunArm64(int task_id) {
return RET_OK;
}
int Convolution1x1Int8CPUKernel::RunArm32(int task_id) {
int cur_stride = thread_stride_ * C4NUM;
int res_stride = matmul_param_->row_ - task_id * thread_stride_ * C4NUM;
int Convolution1x1Int8CPUKernel::RunArm32Hw(int task_id) {
int cur_stride = thread_stride_hw_ * C4NUM;
int res_stride = matmul_param_->row_ - task_id * thread_stride_hw_ * C4NUM;
int cur_hw = MSMIN(cur_stride, res_stride);
if (cur_hw <= 0) {
return RET_OK;
}
int8_t *hw_in = input_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->input_channel_;
int8_t *hw_out = output_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->output_channel_;
int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_;
int32_t *hw_input_sum = filter_peroc_ ? input_sum_ + task_id * thread_stride_ * C4NUM * matmul_param_->col_2_
: input_sum_ + task_id * thread_stride_ * C4NUM;
int8_t *hw_in = input_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->input_channel_;
int8_t *hw_out = output_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->output_channel_;
int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->deep_16_;
int32_t *hw_input_sum = filter_peroc_ ? input_sum_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->col_2_
: input_sum_ + task_id * thread_stride_hw_ * C4NUM;
RowMajor2Row16x4MajorInt8(hw_in, hw_packed_in, cur_hw, matmul_param_->deep_);
@ -405,17 +463,17 @@ int Convolution1x1Int8CPUKernel::RunArm32(int task_id) {
return RET_OK;
}
int Convolution1x1Int8CPUKernel::RunArm64Opt(int task_id) {
int cur_stride = thread_stride_ * C4NUM;
int res_stride = matmul_param_->row_ - task_id * thread_stride_ * C4NUM;
int Convolution1x1Int8CPUKernel::RunArm64OptHw(int task_id) {
int cur_stride = thread_stride_hw_ * C4NUM;
int res_stride = matmul_param_->row_ - task_id * thread_stride_hw_ * C4NUM;
int cur_hw = MSMIN(cur_stride, res_stride);
if (cur_hw <= 0) {
return RET_OK;
}
int8_t *hw_in = input_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->input_channel_;
int8_t *hw_out = output_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->output_channel_;
int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_4_;
int32_t *hw_input_sum = input_sum_ + task_id * thread_stride_ * C4NUM;
int8_t *hw_in = input_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->input_channel_;
int8_t *hw_out = output_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->output_channel_;
int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->deep_4_;
int32_t *hw_input_sum = input_sum_ + task_id * thread_stride_hw_ * C4NUM;
if (filter_peroc_) {
PackInput4x4AndInputSumPert(hw_in, hw_packed_in, hw_input_sum, matmul_param_->deep_, cur_hw, 1);
@ -431,24 +489,84 @@ int Convolution1x1Int8CPUKernel::RunArm64Opt(int task_id) {
return RET_OK;
}
int Convolution1x1Int8CPUKernel::DoRun(int task_id) {
#ifdef ENABLE_ARM32
return RunArm32(task_id);
#else
if (support_optimize_) {
return RunArm64Opt(task_id);
} else {
return RunArm64(task_id);
int Convolution1x1Int8CPUKernel::RunArm32Oc(int task_id) {
int stride = thread_stride_oc_ * C2NUM;
int cur_stride = task_id * stride;
int res_stride = matmul_param_->col_ - cur_stride;
int cur_oc = MSMIN(stride, res_stride);
if (cur_oc <= 0) {
return RET_OK;
}
#endif
int32_t *cur_input_sum = filter_peroc_ ? input_sum_ + cur_stride * matmul_param_->row_4_ : input_sum_;
int32_t *cur_left_shift = filter_peroc_ ? left_shift_ + cur_stride : conv_param_->conv_quant_arg_.left_shift_;
int32_t *cur_right_shift = filter_peroc_ ? right_shift_ + cur_stride : conv_param_->conv_quant_arg_.right_shift_;
int32_t *cur_multiplier = filter_peroc_ ? multiplier_ + cur_stride : conv_param_->conv_quant_arg_.quant_multiplier_;
Conv1x1Int8Arm32(packed_input_, packed_weight_ + cur_stride * matmul_param_->deep_16_, output_ptr_ + cur_stride,
cur_input_sum, reinterpret_cast<int32_t *>(bias_data_) + cur_stride, matmul_param_->row_, cur_oc,
matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_);
return RET_OK;
}
int Convolution1x1Int8Run(void *cdata, int task_id) {
auto conv = reinterpret_cast<Convolution1x1Int8CPUKernel *>(cdata);
auto error_code = conv->DoRun(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv1x1 Int8 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
int Convolution1x1Int8CPUKernel::RunArm64OptOc(int task_id) {
int stride = thread_stride_oc_ * C16NUM;
int cur_stride = task_id * stride;
int res_stride = matmul_param_->col_ - cur_stride;
int cur_oc = MSMIN(stride, res_stride);
if (cur_oc <= 0) {
return RET_OK;
}
int32_t *cur_left_shift = filter_peroc_ ? left_shift_ + cur_stride : conv_param_->conv_quant_arg_.left_shift_;
int32_t *cur_right_shift = filter_peroc_ ? right_shift_ + cur_stride : conv_param_->conv_quant_arg_.right_shift_;
int32_t *cur_multiplier = filter_peroc_ ? multiplier_ + cur_stride : conv_param_->conv_quant_arg_.quant_multiplier_;
Conv1x1Int8Opt(packed_input_, packed_weight_ + cur_stride * matmul_param_->deep_4_, output_ptr_ + cur_stride,
input_sum_, reinterpret_cast<int32_t *>(bias_data_) + cur_stride, matmul_param_->row_, cur_oc,
matmul_param_->deep_4_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_, matmul_func_,
filter_zp_ptr_);
return RET_OK;
}
int Convolution1x1Int8CPUKernel::RunArm64Oc(int task_id) {
int stride = thread_stride_oc_ * C4NUM;
int cur_stride = task_id * stride;
int res_stride = matmul_param_->col_ - cur_stride;
int cur_oc = MSMIN(stride, res_stride);
if (cur_oc <= 0) {
return RET_OK;
}
int32_t *cur_input_sum = filter_peroc_ ? input_sum_ + cur_stride * matmul_param_->row_4_ : input_sum_;
int32_t *cur_left_shift = filter_peroc_ ? left_shift_ + cur_stride : conv_param_->conv_quant_arg_.left_shift_;
int32_t *cur_right_shift = filter_peroc_ ? right_shift_ + cur_stride : conv_param_->conv_quant_arg_.right_shift_;
int32_t *cur_multiplier = filter_peroc_ ? multiplier_ + cur_stride : conv_param_->conv_quant_arg_.quant_multiplier_;
Conv1x1Int8(packed_input_, packed_weight_ + cur_stride * matmul_param_->deep_16_, output_ptr_ + cur_stride,
cur_input_sum, reinterpret_cast<int32_t *>(bias_data_) + cur_stride, matmul_param_->row_, cur_oc,
matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_);
return RET_OK;
}
int Convolution1x1Int8CPUKernel::OcOptPre(int task_id) {
int cur_stride = thread_stride_hw_ * C4NUM;
int res_stride = matmul_param_->row_ - task_id * thread_stride_hw_ * C4NUM;
int cur_hw = MSMIN(cur_stride, res_stride);
if (cur_hw <= 0) {
return RET_OK;
}
int8_t *hw_in = input_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->input_channel_;
int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->deep_4_;
int32_t *hw_input_sum = input_sum_ + task_id * thread_stride_hw_ * C4NUM;
if (filter_peroc_) {
PackInput4x4AndInputSumPert(hw_in, hw_packed_in, hw_input_sum, matmul_param_->deep_, cur_hw, 1);
} else {
PackInput4x4AndInputSumPert(hw_in, hw_packed_in, hw_input_sum, matmul_param_->deep_, cur_hw,
conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_);
}
return RET_OK;
}
@ -461,22 +579,34 @@ int Convolution1x1Int8CPUKernel::Run() {
return RET_ERROR;
}
int8_t *src_in = reinterpret_cast<int8_t *>(in_tensors_[0]->MutableData());
int8_t *src_out = reinterpret_cast<int8_t *>(out_tensors_[0]->MutableData());
int8_t *src_in = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
int8_t *src_out = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
Pre1x1Trans(src_in + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_,
src_out + batch_index * matmul_param_->row_ * matmul_param_->col_);
auto ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8Run, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ParallelLaunch run error error_code[" << ret << "]";
if (parallel_by_oc_) {
/* input transpose and input sum */
if (support_optimize_) {
ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8OcOptPre, this, thread_count_hw_);
} else {
RowMajor2Row16x4MajorInt8(input_ptr_, packed_input_, matmul_param_->row_, matmul_param_->deep_);
PackInputSum16x4Int8(packed_input_, input_sum_, filter_zp_ptr_, conv_param_);
}
/* matmul parallel by oc */
error_code = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8OcRun, this, thread_count_oc_);
} else {
/* matmul parallel by hw */
error_code = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8HwRun, this, thread_count_hw_);
}
if (error_code != RET_OK) {
MS_LOG(ERROR) << "ParallelLaunch run error error_code[" << error_code << "]";
FreeRunBuf();
return ret;
return error_code;
}
}
FreeRunBuf();
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -45,12 +45,17 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
void FreeRunBuf();
public:
int DoRun(int task_id);
int OcRun(int task_id);
int HwRun(int task_id);
int OcOptPre(int task_id);
private:
int RunArm32(int task_id);
int RunArm64(int task_id);
int RunArm64Opt(int task_id);
int RunArm32Oc(int task_id);
int RunArm64Oc(int task_id);
int RunArm64OptOc(int task_id);
int RunArm32Hw(int task_id);
int RunArm64Hw(int task_id);
int RunArm64OptHw(int task_id);
private:
void FreeResizeBuf();
@ -71,9 +76,12 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
int8_t *packed_input_ = nullptr;
int8_t *input_ptr_ = nullptr;
int8_t *output_ptr_ = nullptr;
size_t thread_count_ = 1;
size_t thread_stride_ = 0;
size_t thread_count_hw_ = 1;
size_t thread_stride_hw_ = 0;
size_t thread_count_oc_ = 1;
size_t thread_stride_oc_ = 0;
bool pre_trans_input_ = false;
bool parallel_by_oc_ = false;
size_t input_sum_size_ = 0;
MatMulParameter *matmul_param_ = nullptr;
MATMUL_OPT_DP_FUNC matmul_func_ = nullptr;