!7355 [MS][LITE][CPU] replace int8 common conv with int8 matmul

Merge pull request !7355 from fuzhiye/tmp

mindspore-ci-bot 2020-10-19 14:02:04 +08:00, committed by Gitee
commit da43d8e47d
7 changed files with 133 additions and 193 deletions

View File

@@ -313,9 +313,9 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
   }
 }
 
-void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, const int32_t *bias_data,
-                 int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id,
-                 ConvParameter *conv_param, GEMM_FUNC gemm_func) {
+void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight,
+                 const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id,
+                 ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func) {
   int kernel_h = conv_param->kernel_h_;
   int kernel_w = conv_param->kernel_w_;
   int in_batch = conv_param->input_batch_;
@@ -325,20 +325,22 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
   int out_h = conv_param->output_h_;
   int out_w = conv_param->output_w_;
   int out_channel = conv_param->output_channel_;
-  int oc4 = UP_DIV(out_channel, C4NUM);
-  int32_t input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
+  int oc8 = UP_DIV(out_channel, C8NUM);
   int tile_n = conv_param->tile_num_;
   int thread_count = conv_param->thread_num_;
   int output_count = out_h * out_w;
   int output_tile_count = UP_DIV(output_count, tile_n);
   int ic4 = UP_DIV(in_channel, C4NUM);
   int kernel_plane = kernel_h * kernel_w;
-  int unit_size = kernel_plane * ic4 * C4NUM;
+  int unit_size = UP_ROUND(kernel_plane * in_channel, C4NUM);
   int input_sum_offset;
+  bool per_channel;
   if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) {
-    input_sum_offset = tile_n * oc4 * C4NUM;
+    input_sum_offset = tile_n * oc8 * C8NUM;
+    per_channel = true;
   } else {
     input_sum_offset = tile_n;
+    per_channel = false;
   }
 
   for (int b = 0; b < in_batch; b++) {
@@ -349,27 +349,18 @@
       int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
       int32_t *tmp_input_sum = input_sum + task_id * input_sum_offset;
       int8_t *gemm_input = packed_input + task_id * unit_size * tile_n;
-      // clear tmp buffer before compute
-      memset(gemm_input, (int8_t)input_zp, unit_size * tile_n);
+      int8_t *matmul = matmul_input + task_id * kernel_plane * in_channel * tile_n;
+      memset(matmul, conv_param->conv_quant_arg_.input_quant_args_[0].zp_, kernel_plane * in_channel * tile_n);
+      Im2ColPackUnitInt8Opt(input_data + in_batch_offset, gemm_input, matmul, real_cal_num, start_index, filter_zp,
+                            tmp_input_sum, conv_param, per_channel);
 
       int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
-      size_t tmp_dst_size = tile_n * conv_param->output_channel_ * sizeof(int32_t);
-      int tmp_dst_offset = task_id * tile_n * conv_param->output_channel_;
-      memset(tmp_dst + tmp_dst_offset, 0, tmp_dst_size);
-
-      Im2ColPackUnitInt8Opt(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, tmp_input_sum,
-                            conv_param);
-      if (real_cal_num == tile_n) {
-        int8_t *gemm_output = output_data + out_offset;
-        IndirectGemmInt8Opt(gemm_output, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4,
-                            kernel_plane, out_channel, tmp_input_sum, conv_param, gemm_func);
-      } else {
-        // res part
-        int8_t *tmp_out_ptr = tmp_out + task_id * tile_n * out_channel;
-        IndirectGemmInt8Opt(tmp_out_ptr, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4,
-                            kernel_plane, out_channel, tmp_input_sum, conv_param, gemm_func);
-        memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel);
-      }
+      int8_t *gemm_output = output_data + out_offset;
+      matmul_func(gemm_input, packed_weight, gemm_output, real_cal_num, out_channel, unit_size, out_channel,
+                  tmp_input_sum, bias_data, conv_param->conv_quant_arg_.left_shift_,
+                  conv_param->conv_quant_arg_.right_shift_, conv_param->conv_quant_arg_.quant_multiplier_,
+                  conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0],
+                  conv_param->conv_quant_arg_.out_act_max_[0], per_channel);
     }
   }
 }
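
Note: matmul_func collapses the old two-stage IndirectGemm-plus-copy-out into one fused int8 matmul with requantization. As orientation only, a minimal sketch of what each output element is expected to undergo, reconstructed from the arguments forwarded above. The rounding scheme (gemmlowp-style doubling high multiply), the sign convention of right_shift, and the use of scalars where the per-channel kernel indexes arrays are all assumptions; the shipped NEON kernel may order these steps differently.

/* Hedged sketch, not the shipped kernel: requantize one int32 accumulator
 * to int8 using the arguments ConvInt8Opt forwards to matmul_func. */
#include <stdint.h>

static inline int8_t RequantizeOne(int32_t acc, int32_t bias, int32_t input_sum_term,
                                   int32_t left_shift, int32_t right_shift,
                                   int32_t multiplier, int32_t out_zp,
                                   int32_t act_min, int32_t act_max) {
  acc += bias;            /* folded bias, includes the zero-point corrections */
  acc -= input_sum_term;  /* per-row filter_zp * sum(input) correction */
  /* rounding doubling high multiply, then rounding right shift (conventions assumed) */
  int64_t prod = ((int64_t)acc << left_shift) * (int64_t)multiplier;
  int32_t high = (int32_t)((prod + (1LL << 30)) >> 31);
  if (right_shift > 0) {
    high = (high + (1 << (right_shift - 1))) >> right_shift;
  }
  high += out_zp;
  if (high < act_min) high = act_min;
  if (high > act_max) high = act_max;
  return (int8_t)high;
}
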
@@ -708,6 +701,13 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *
         pack_ic += 1;
       }
+      for (int ici = input_channel; ici < ic4; ici += 1) {
+        for (int i = 0; i < C8NUM; i++) {
+          pack_ic[i * C4NUM] = 0;
+        }
+        pack_ic += 1;
+      }
+
       for (int oci = 0; oci < oc_8div; oci += C8NUM) {
         for (int ri = 0; ri < C8NUM; ri++) {
           input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0];
@@ -975,6 +975,13 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
         pack_ic += 1;
       }
+      for (int ici = input_channel; ici < ic4; ici += 1) {
+        for (int i = 0; i < C8NUM; i++) {
+          pack_ic[i * C4NUM] = 0;
+        }
+        pack_ic += 1;
+      }
+
       for (int i = 0; i < C8NUM; i++) {
         input_sum_r[i] = tmp_sum_value[i] * filter_zp;
       }
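
Note: the two zero-fill loops added above pad the deep dimension (rounded up to a multiple of C4NUM) with zeros, so the padded lanes contribute nothing to either the raw accumulators or the input sums. For reference, a minimal sketch of what the per-tensor pre-computation boils down to per packed row; the function and parameter names here are hypothetical, not the nnacl API:

#include <stddef.h>
#include <stdint.h>

/* Hedged reference: one correction value per output-tile row, equal to
 * filter_zp times the sum of that row's (zero-padded) deep elements. */
void InputSumPerTensorRef(const int8_t *rows, size_t row_count, size_t deep,
                          int32_t filter_zp, int32_t *input_sum) {
  for (size_t r = 0; r < row_count; ++r) {
    int32_t acc = 0;
    for (size_t d = 0; d < deep; ++d) {
      acc += rows[r * deep + d];  /* zero padding past the real channels adds nothing */
    }
    input_sum[r] = acc * filter_zp;
  }
}
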

View File

@@ -49,9 +49,9 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
               int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id,
               ConvParameter *conv_param);
-void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, const int32_t *bias_data,
-                 int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id,
-                 ConvParameter *conv_param, GEMM_FUNC gemm_func);
+void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight,
+                 const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id,
+                 ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func);
 
 // int8 convolution 1x1
 void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
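
Note: MATMUL_OPT_R_FUNC itself is declared in nnacl's matmul headers, not in this diff. Inferred purely from the ConvInt8Opt call site, its shape is roughly the sketch below; the parameter names, const qualifiers, and integer widths are guesses, not the actual typedef:

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Hedged reconstruction from the call site; the real declaration may differ. */
typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, size_t row,
                                  size_t col, size_t deep_4, size_t dst_stride, const int32_t *input_sum,
                                  const int32_t *bias, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier,
                                  int32_t output_zp, int32_t act_min, int32_t act_max, bool per_channel);
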

View File

@@ -17,6 +17,7 @@
 #include "nnacl/pack.h"
 #include <string.h>
 #include <stdlib.h>
+#include "nnacl/int8/conv_int8.h"
 
 void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) {
   return PackNCHWToNHWCFp32(src, dst, 1, plane, channel);
@@ -80,54 +81,6 @@ void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *pack
   }  // kernel plane loop
 }
 
-void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum) {
-  // original weight format : ohwi
-  int kernel_h = conv_param->kernel_h_;
-  int kernel_w = conv_param->kernel_w_;
-  int in_channel = conv_param->input_channel_;
-  int out_channel = conv_param->output_channel_;
-  int oc4 = UP_DIV(out_channel, C4NUM);
-  int ic4 = UP_DIV(in_channel, C4NUM);
-  int kernel_plane = kernel_h * kernel_w;
-  int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane;
-  int unit_size = C4NUM * C4NUM;
-  int block_size = pack_weight_size / oc4;
-  QuantArg *filter_args = conv_param->conv_quant_arg_.filter_quant_args_;
-  for (int m = 0; m < kernel_plane; m++) {
-    int kernel_plane_stride = m * in_channel;
-    int packed_kernel_plane_stride = m * unit_size * ic4;
-    for (int i = 0; i < ic4; i++) {
-      int channel_block_stride = kernel_plane_stride + i * C4NUM;
-      int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size;
-      int ic_remainder = in_channel - i * C4NUM;
-      int real_ic_num = ic_remainder < C4NUM ? ic_remainder : C4NUM;
-      for (int h = 0; h < real_ic_num; h++) {
-        int block_stride = channel_block_stride + h;
-        int packed_block_stride = packed_channel_block_size + h;
-        for (int j = 0; j < oc4; j++) {
-          int kernel_block_stride = block_stride + j * C4NUM * kernel_plane * in_channel;
-          int packed_kernel_block_size = packed_block_stride + j * block_size;
-          int oc_remainder = out_channel - j * C4NUM;
-          int real_oc_num = oc_remainder < C4NUM ? oc_remainder : C4NUM;
-          for (int k = 0; k < real_oc_num; k++) {
-            int8_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel;
-            int8_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k * C4NUM;
-            *packed_data_ptr = origin_data_ptr[0];
-            int32_t f_zp;
-            if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) {
-              f_zp = filter_args[j * C4NUM + k].zp_;
-            } else {
-              f_zp = filter_args[0].zp_;
-            }
-            weight_sum[j * C4NUM + k] += (int32_t)(packed_data_ptr[0] - f_zp);
-          }
-        }  // kernel block loop
-      }    // inchannel block loop
-    }      // channel block loop
-  }        // kernel plane loop
-}
-
 void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size) {
   /* support nhwc */
   char *src = (char *)src_ptr;
@@ -391,11 +344,10 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
   }  // tile num loop
 }
 
-void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
-                           int32_t *input_sum, ConvParameter *conv_param) {
+void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
+                           int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param,
+                           bool per_channel) {
   // input format : nhwc
-  int tile_num = conv_param->tile_num_;
-  QuantArg *filter_arg = conv_param->conv_quant_arg_.filter_quant_args_;
   int kernel_h = conv_param->kernel_h_;
   int kernel_w = conv_param->kernel_w_;
   int stride_h = conv_param->stride_h_;
@@ -407,11 +359,8 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
   int in_channel = conv_param->input_channel_;
   int in_h = conv_param->input_h_;
   int in_w = conv_param->input_w_;
-  int ic4_minus = in_channel / C4NUM;
-  int ic4 = UP_DIV(in_channel, C4NUM);
-  int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
   int out_w = conv_param->output_w_;
-  int block_size = kernel_h * kernel_w;
+  int kernel_plane = kernel_h * kernel_w;
 
   for (int i = 0; i < real_cal_num; i++) {
     int block_start = block_index + i;
@@ -422,47 +371,30 @@
     int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
     int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
     int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
-    for (int j = kh_s; j < kh_e; j++) {
-      int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
-      for (int n = kw_s; n < kw_e; n++) {
-        int input_x_stride = input_y_stride + n * dilation_w * in_channel;
-        int input_plane_offset = (j * kernel_w + n) * tile_num * C4NUM * ic4 + i * C4NUM;
-        for (int m = 0; m < ic4_minus; m++) {
-          int channel_block_stride = input_x_stride + m * C4NUM;
-          int channel_block_offset = input_plane_offset + m * tile_num * C4NUM;
-          memcpy(packed_input + channel_block_offset, input_data + channel_block_stride, 4);
-        }  // channel_block loop
-        int ic_res = conv_param->input_channel_ - ic4_minus * C4NUM;
-        for (int l = 0; l < ic_res; ++l) {
-          int channel_block_stride = input_x_stride + ic4_minus * C4NUM + l;
-          int channel_block_offset = input_plane_offset + ic4_minus * tile_num * C4NUM + l;
-          packed_input[channel_block_offset] = input_data[channel_block_stride];
-        }
-      }  // kernel_w loop
-    }    // kernel_h loop
-    int32_t input_accumulator = 0;
-    for (int j = 0; j < block_size; j++) {
-      int block_offset = j * tile_num * ic4 * C4NUM + i * C4NUM;
-      for (int c = 0; c < ic4; c++) {
-        int ic4_offset = block_offset + c * tile_num * C4NUM;
-        for (int k = 0; k < C4NUM; ++k) {
-          input_accumulator += (packed_input + ic4_offset)[k];
-        }
-      }
-    }
-    if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC)) {
-      continue;
-    } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
-               (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
-      int cal_num_offset = i * oc4 * C4NUM;
-      for (int l = 0; l < conv_param->output_channel_; ++l) {
-        input_sum[cal_num_offset + l] = input_accumulator * filter_arg[l].zp_;
-      }
-    } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
-               !(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
-      input_sum[i] = input_accumulator * filter_arg[0].zp_;
-    }
+    if (dilation_w == 1 && dilation_h == 1) {
+      for (int j = kh_s; j < kh_e; j++) {
+        int input_y_stride = j * in_w * in_channel + input_stride;
+        int input_x_stride = input_y_stride + kw_s * in_channel;
+        int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
+        memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, (kw_e - kw_s) * in_channel);
+      }  // kernel_h loop
+    } else {
+      for (int j = kh_s; j < kh_e; j++) {
+        int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
+        for (int k = kw_s; k < kw_e; ++k) {
+          int input_x_stride = input_y_stride + k * dilation_w * in_channel;
+          int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
+          memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, in_channel);
+        }
+      }  // kernel_h loop
+    }
   }  // tile num loop
+  if (per_channel) {
+    Conv1x1PreOptPeroc(matmul_input, packed_input, input_sum, kernel_plane * in_channel, conv_param->output_channel_,
+                       real_cal_num, filter_zp, C8NUM * C8NUM);
+  } else {
+    Conv1x1PreOptPert(matmul_input, packed_input, input_sum, kernel_plane * in_channel, real_cal_num, conv_param);
+  }
 }
 
 void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param) {
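
Note: the rewritten Im2ColPackUnitInt8Opt emits plain row-major patches. Row i of matmul_input is the flattened [kh][kw][in_channel] window for output pixel i, so with dilation 1 the kw_s..kw_e span is contiguous in the NHWC source and one memcpy per kernel row suffices; the packing into the 8x4 tile layout then happens in Conv1x1PreOptPeroc/Pert. A naive reference of the same layout, assuming a window that lies fully inside the image (names hypothetical):

#include <stdint.h>

/* Hedged reference: copy the NHWC window at (in_y, in_x) into one
 * row-major im2col row, element by element. */
void Im2ColRowMajorRef(const int8_t *img, int8_t *row, int in_w, int in_c,
                       int kernel_h, int kernel_w, int in_y, int in_x) {
  for (int kh = 0; kh < kernel_h; ++kh) {
    for (int kw = 0; kw < kernel_w; ++kw) {
      for (int c = 0; c < in_c; ++c) {
        row[(kh * kernel_w + kw) * in_c + c] =
          img[((in_y + kh) * in_w + (in_x + kw)) * in_c + c];
      }
    }
  }
}
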

View File

@@ -35,8 +35,9 @@ void PackHWCToWHC(const float *src, float *dst, int height, int width, int chann
 void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
                         int32_t *input_sum, ConvParameter *conv_param);
-void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
-                           int32_t *input_sum, ConvParameter *conv_param);
+void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
+                           int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param,
+                           bool per_channel);
 
 void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);
@@ -52,8 +53,6 @@ void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
 void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum);
 
-void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum);
-
 void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param);
 
 void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);

View File

@@ -32,7 +32,8 @@ using mindspore::schema::PrimitiveType_Conv2D;
 namespace mindspore::kernel {
 void ConvolutionInt8CPUKernel::CheckSupportOptimize() {
-  tile_num_ = 24;
+  tile_num_ = 8;
+  matmul_func_ = MatMulInt8_8x8_r;
 #ifdef ENABLE_ARM32
   tile_num_ = 2;
   support_optimize_ = false;
@@ -42,19 +43,19 @@ void ConvolutionInt8CPUKernel::CheckSupportOptimize() {
   void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
   if (optimize_op_handler != nullptr) {
     dlerror();
-    *(reinterpret_cast<void **>(&gemm_func_)) = dlsym(optimize_op_handler, "IndirectGemmInt8_optimize_handler");
+    *(reinterpret_cast<void **>(&matmul_func_)) = dlsym(optimize_op_handler, "MatMulRInt8_optimize_handler");
     auto dlopen_error = dlerror();
     if (dlopen_error != nullptr) {
-      MS_LOG(ERROR) << "load gemm func failed! " << dlopen_error << ".";
-      tile_num_ = 4;
+      MS_LOG(ERROR) << "load matmul func failed! " << dlopen_error << ".";
       support_optimize_ = false;
-      gemm_func_ = nullptr;
+      matmul_func_ = nullptr;
     } else {
-      // do nothing
+      support_optimize_ = true;
     }
   } else {
     tile_num_ = 4;
     support_optimize_ = false;
+    matmul_func_ = nullptr;
   }
 #endif
   conv_param_->tile_num_ = tile_num_;
@@ -141,7 +142,6 @@ int ConvolutionInt8CPUKernel::InitWeightBias() {
 int ConvolutionInt8CPUKernel::InitTmpBuffer() {
   MS_ASSERT(ctx_->allocator != nullptr);
-  int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
   int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_;
   int plane_c4 = UP_DIV(kernel_plane, C4NUM);
@@ -176,11 +176,10 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
   int kernel_w = filter_tensor->Width();
   conv_param_->input_channel_ = input_channel;
   conv_param_->output_channel_ = output_channel;
-  int ic4 = UP_DIV(input_channel, C4NUM);
-  int oc4 = UP_DIV(output_channel, C4NUM);
+  int oc8 = UP_DIV(output_channel, C8NUM);
   int kernel_plane = kernel_h * kernel_w;
-  int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane;
-  auto filter_arg = conv_param_->conv_quant_arg_.filter_quant_args_;
+  int up_round_deep = UP_ROUND(kernel_plane * input_channel, C4NUM);
+  int pack_weight_size = oc8 * C8NUM * up_round_deep;
   int32_t input_zp = conv_param_->conv_quant_arg_.input_quant_args_[0].zp_;
 
   // init weight
@@ -191,27 +190,15 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
     return RET_ERROR;
   }
   memset(packed_weight_, 0, pack_weight_size);
-  auto *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * output_channel));
-  if (weight_sum == nullptr) {
-    MS_LOG(ERROR) << "malloc weight_sum failed.";
-    return RET_ERROR;
-  }
-  for (int i = 0; i < output_channel; i++) {
-    if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
-      weight_sum[i] = ic4 * C4NUM * kernel_plane * filter_arg[i].zp_;
-    } else {
-      weight_sum[i] = ic4 * C4NUM * kernel_plane * filter_arg[0].zp_;
-    }
-  }
-  PackWeightInt8Opt(origin_weight, conv_param_, packed_weight_, weight_sum);
+  RowMajor2Row8x4MajorInt8(origin_weight, packed_weight_, output_channel, input_channel * kernel_plane);
 
   // init bias
-  bias_data_ = reinterpret_cast<int32_t *>(malloc(oc4 * C4NUM * sizeof(int32_t)));
+  bias_data_ = reinterpret_cast<int32_t *>(malloc(oc8 * C8NUM * sizeof(int32_t)));
   if (bias_data_ == nullptr) {
     MS_LOG(ERROR) << "malloc bias_data_ failed.";
     return RET_ERROR;
   }
-  memset(bias_data_, 0, oc4 * C4NUM * sizeof(int32_t));
+  memset(bias_data_, 0, oc8 * C8NUM * sizeof(int32_t));
   if (in_tensors_.size() == kInputSize2) {
     auto ori_bias = reinterpret_cast<int32_t *>(in_tensors_.at(kBiasIndex)->MutableData());
     memcpy(bias_data_, ori_bias, output_channel * sizeof(int32_t));
@@ -219,21 +206,26 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
     MS_ASSERT(in_tensors_.size() == kInputSize1);
   }
   auto *bias_data = reinterpret_cast<int32_t *>(bias_data_);
-  int c4_kernel_plane_size = kernel_plane * ic4 * C4NUM;
-  if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
-    for (int i = 0; i < output_channel; i++) {
-      bias_data[i] += filter_arg[i].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
-    }
-  } else {
-    for (int i = 0; i < output_channel; i++) {
-      bias_data[i] += filter_arg[0].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
-    }
+  bool filter_peroc = conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL;
+  if (filter_peroc) {
+    filter_zp_ptr_ = reinterpret_cast<int32_t *>(malloc(output_channel * sizeof(int32_t)));
+  }
+  for (int oc = 0; oc < output_channel; oc++) {
+    int32_t filter_zp = conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_;
+    if (filter_peroc) {
+      filter_zp = conv_param_->conv_quant_arg_.filter_quant_args_[oc].zp_;
+      filter_zp_ptr_[oc] = filter_zp;
+    }
+    int32_t weight_sum_value = up_round_deep * filter_zp;
+    for (int i = 0; i < kernel_plane * input_channel; i++) {
+      weight_sum_value += origin_weight[oc * kernel_plane * input_channel + i] - filter_zp;
+    }
+    bias_data[oc] += filter_zp * input_zp * up_round_deep - weight_sum_value * input_zp;
   }
-  free(weight_sum);
 
   size_t input_sum_size;
   if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
-    input_sum_size = oc4 * C4NUM * tile_num_ * thread_count_ * sizeof(int32_t);
+    input_sum_size = oc8 * C8NUM * tile_num_ * thread_count_ * sizeof(int32_t);
   } else {
     input_sum_size = tile_num_ * thread_count_ * sizeof(int32_t);
   }
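
Note: the rewritten bias loop folds the weight-side zero-point terms once at init time. Expanding the quantized dot product gives sum_k (x_k - zp_x)(w_k - zp_w) = sum_k x_k*w_k - zp_w * sum_k x_k - zp_x * sum_k w_k + K * zp_x * zp_w. The last two terms depend only on the weights and land in bias_data here; the zp_w * sum(x) term varies per input tile and is the input_sum the matmul kernel subtracts at runtime. Padding the deep dimension to up_round_deep is harmless because weight_sum_value counts each padded position as filter_zp, which cancels against the filter_zp * input_zp * up_round_deep term. A small self-check of the identity:

#include <stdint.h>
#include <stdio.h>

int main(void) {
  enum { K = 5 };
  int32_t x[K] = {3, -7, 12, 0, 5}, w[K] = {-2, 4, 1, 9, -6};
  int32_t zp_x = 2, zp_w = -1;

  int32_t exact = 0, raw = 0, sum_x = 0, sum_w = 0;
  for (int k = 0; k < K; ++k) {
    exact += (x[k] - zp_x) * (w[k] - zp_w);
    raw += x[k] * w[k];
    sum_x += x[k];
    sum_w += w[k];
  }
  /* identity: exact = raw - zp_w*sum_x - zp_x*sum_w + K*zp_x*zp_w */
  int32_t bias_fold = K * zp_x * zp_w - sum_w * zp_x;  /* mirrors bias_data[oc] above */
  int32_t input_sum = zp_w * sum_x;                    /* mirrors input_sum per tile  */
  printf("%d == %d\n", exact, raw + bias_fold - input_sum);  /* prints -61 == -61 */
  return 0;
}
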
@@ -248,26 +240,17 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
 int ConvolutionInt8CPUKernel::InitTmpBufferOpt() {
   MS_ASSERT(ctx_->allocator != nullptr);
-  int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
-  size_t nhwc4_input_size = ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_;
-  nhwc4_input_ = ctx_->allocator->Malloc(nhwc4_input_size);
-  if (nhwc4_input_ == nullptr) {
-    MS_LOG(ERROR) << "malloc nhwc4 input failed.";
-    return RET_ERROR;
-  }
-
-  size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t);
-  tmp_dst_ = reinterpret_cast<int32_t *>(ctx_->allocator->Malloc(tmp_dst_size));
-  if (tmp_dst_ == nullptr) {
-    MS_LOG(ERROR) << "malloc tmp_dst_ failed.";
-    return RET_ERROR;
-  }
-  tmp_out_ =
-    reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(thread_count_ * tile_num_ * conv_param_->output_channel_));
-  if (tmp_out_ == nullptr) {
-    MS_LOG(ERROR) << "malloc tmp_out_ failed.";
+  int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_;
+  int tmp_unit = UP_ROUND(kernel_plane * conv_param_->input_channel_, C4NUM);
+  matmul_packed_input_ = reinterpret_cast<int8_t *>(
+    ctx_->allocator->Malloc(thread_count_ * tile_num_ * kernel_plane * conv_param_->input_channel_));
+  if (matmul_packed_input_ == nullptr) {
+    MS_LOG(ERROR) << "malloc matmul_packed_input_ failed.";
+    return RET_ERROR;
+  }
+  packed_input_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(tmp_unit * thread_count_ * tile_num_));
+  if (packed_input_ == nullptr) {
+    MS_LOG(ERROR) << "malloc packed_input_ failed.";
     return RET_ERROR;
   }
   return RET_OK;
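
Note, to make the new buffer sizes concrete with assumed shapes: for a 3x3 kernel with input_channel = 17, tile_num_ = 8 and thread_count_ = 4, kernel_plane * input_channel = 153, so matmul_packed_input_ holds 4 * 8 * 153 = 4896 bytes of raw row-major patches, while tmp_unit = UP_ROUND(153, C4NUM) = 156 gives packed_input_ 156 * 4 * 8 = 4992 bytes for the deep-padded, tile-packed input.
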
@@ -321,8 +304,9 @@ int ConvolutionInt8CPUKernel::RunImpl(int task_id) {
   auto ori_input_data = reinterpret_cast<int8_t *>(input_tensor->MutableData());
   auto output_addr = reinterpret_cast<int8_t *>(out_tensors_.at(kOutputIndex)->MutableData());
   if (support_optimize_) {
-    ConvInt8Opt(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), tmp_dst_,
-                tmp_out_, output_addr, input_sum_, task_id, conv_param_, gemm_func_);
+    ConvInt8Opt(ori_input_data, packed_input_, matmul_packed_input_, packed_weight_,
+                reinterpret_cast<int32_t *>(bias_data_), output_addr, filter_zp_ptr_, input_sum_, task_id, conv_param_,
+                matmul_func_);
   } else {
     ConvInt8(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), tmp_dst_, tmp_out_,
              output_addr, input_sum_, task_id, conv_param_);
@@ -346,11 +330,19 @@ int ConvolutionInt8CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare failed.";
     return RET_ERROR;
   }
-  // init tmp input, output
-  ret = InitTmpBuffer();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Init tmp buffer failed.";
-    return RET_ERROR;
+
+  if (support_optimize_) {
+    ret = InitTmpBufferOpt();
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Init tmp buffer failed.";
+      return RET_ERROR;
+    }
+  } else {
+    ret = InitTmpBuffer();
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Init tmp buffer failed.";
+      return RET_ERROR;
+    }
   }
 
   int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionInt8Impl, this, thread_count_);

View File

@@ -40,6 +40,10 @@ class ConvolutionInt8CPUKernel : public ConvolutionBaseCPUKernel {
       free(input_sum_);
       input_sum_ = nullptr;
     }
+    if (filter_zp_ptr_ != nullptr) {
+      free(filter_zp_ptr_);
+      filter_zp_ptr_ = nullptr;
+    }
   }
 
   int Init() override;
@@ -58,6 +62,10 @@ class ConvolutionInt8CPUKernel : public ConvolutionBaseCPUKernel {
       ctx_->allocator->Free(packed_input_);
       packed_input_ = nullptr;
     }
+    if (matmul_packed_input_ != nullptr) {
+      ctx_->allocator->Free(matmul_packed_input_);
+      matmul_packed_input_ = nullptr;
+    }
     if (tmp_dst_ != nullptr) {
       ctx_->allocator->Free(tmp_dst_);
       tmp_dst_ = nullptr;
@@ -70,10 +78,12 @@ class ConvolutionInt8CPUKernel : public ConvolutionBaseCPUKernel {
   bool support_optimize_ = true;
   int8_t *packed_weight_ = nullptr;
   int8_t *packed_input_ = nullptr;
+  int8_t *matmul_packed_input_ = nullptr;
+  int32_t *filter_zp_ptr_ = nullptr; /* per-oc */
   int32_t *input_sum_ = nullptr;
   int32_t *tmp_dst_ = nullptr;
   int8_t *tmp_out_ = nullptr;
-  GEMM_FUNC gemm_func_ = nullptr;
+  MATMUL_OPT_R_FUNC matmul_func_ = nullptr;
 };
 }  // namespace mindspore::kernel

View File

@@ -483,9 +483,9 @@ function Run_arm64() {
         echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> adb_run_cmd.txt
         adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
         if [ $? = 0 ]; then
-            run_result='arm64: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
+            run_result='arm64_awq: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
         else
-            run_result='arm64: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
+            run_result='arm64_awq: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
         fi
     done < ${models_tflite_awaretraining_config}