!7355 [MS][LITE][CPU]replace int8 common conv with int8 matmul

Merge pull request !7355 from fuzhiye/tmp
This commit is contained in:
mindspore-ci-bot 2020-10-19 14:02:04 +08:00 committed by Gitee
commit da43d8e47d
7 changed files with 133 additions and 193 deletions

View File

@ -313,9 +313,9 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
}
}
void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, const int32_t *bias_data,
int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id,
ConvParameter *conv_param, GEMM_FUNC gemm_func) {
void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight,
const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id,
ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func) {
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int in_batch = conv_param->input_batch_;
@ -325,20 +325,22 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
int oc4 = UP_DIV(out_channel, C4NUM);
int32_t input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
int oc8 = UP_DIV(out_channel, C8NUM);
int tile_n = conv_param->tile_num_;
int thread_count = conv_param->thread_num_;
int output_count = out_h * out_w;
int output_tile_count = UP_DIV(output_count, tile_n);
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int unit_size = kernel_plane * ic4 * C4NUM;
int unit_size = UP_ROUND(kernel_plane * in_channel, C4NUM);
int input_sum_offset;
bool per_channel;
if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) {
input_sum_offset = tile_n * oc4 * C4NUM;
input_sum_offset = tile_n * oc8 * C8NUM;
per_channel = true;
} else {
input_sum_offset = tile_n;
per_channel = false;
}
for (int b = 0; b < in_batch; b++) {
@ -349,27 +351,18 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
int32_t *tmp_input_sum = input_sum + task_id * input_sum_offset;
int8_t *gemm_input = packed_input + task_id * unit_size * tile_n;
// clear tmp buffer before compute
memset(gemm_input, (int8_t)input_zp, unit_size * tile_n);
int8_t *matmul = matmul_input + task_id * kernel_plane * in_channel * tile_n;
memset(matmul, conv_param->conv_quant_arg_.input_quant_args_[0].zp_, kernel_plane * in_channel * tile_n);
Im2ColPackUnitInt8Opt(input_data + in_batch_offset, gemm_input, matmul, real_cal_num, start_index, filter_zp,
tmp_input_sum, conv_param, per_channel);
int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
size_t tmp_dst_size = tile_n * conv_param->output_channel_ * sizeof(int32_t);
int tmp_dst_offset = task_id * tile_n * conv_param->output_channel_;
memset(tmp_dst + tmp_dst_offset, 0, tmp_dst_size);
Im2ColPackUnitInt8Opt(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, tmp_input_sum,
conv_param);
if (real_cal_num == tile_n) {
int8_t *gemm_output = output_data + out_offset;
IndirectGemmInt8Opt(gemm_output, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4,
kernel_plane, out_channel, tmp_input_sum, conv_param, gemm_func);
} else {
// res part
int8_t *tmp_out_ptr = tmp_out + task_id * tile_n * out_channel;
IndirectGemmInt8Opt(tmp_out_ptr, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4,
kernel_plane, out_channel, tmp_input_sum, conv_param, gemm_func);
memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel);
}
matmul_func(gemm_input, packed_weight, gemm_output, real_cal_num, out_channel, unit_size, out_channel,
tmp_input_sum, bias_data, conv_param->conv_quant_arg_.left_shift_,
conv_param->conv_quant_arg_.right_shift_, conv_param->conv_quant_arg_.quant_multiplier_,
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0],
conv_param->conv_quant_arg_.out_act_max_[0], per_channel);
}
}
}
@ -708,6 +701,13 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *
pack_ic += 1;
}
for (int ici = input_channel; ici < ic4; ici += 1) {
for (int i = 0; i < C8NUM; i++) {
pack_ic[i * C4NUM] = 0;
}
pack_ic += 1;
}
for (int oci = 0; oci < oc_8div; oci += C8NUM) {
for (int ri = 0; ri < C8NUM; ri++) {
input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0];
@ -975,6 +975,13 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
pack_ic += 1;
}
for (int ici = input_channel; ici < ic4; ici += 1) {
for (int i = 0; i < C8NUM; i++) {
pack_ic[i * C4NUM] = 0;
}
pack_ic += 1;
}
for (int i = 0; i < C8NUM; i++) {
input_sum_r[i] = tmp_sum_value[i] * filter_zp;
}

View File

@ -49,9 +49,9 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id,
ConvParameter *conv_param);
void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, const int32_t *bias_data,
int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id,
ConvParameter *conv_param, GEMM_FUNC gemm_func);
void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight,
const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id,
ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func);
// int8 convolution 1x1
void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,

View File

@ -17,6 +17,7 @@
#include "nnacl/pack.h"
#include <string.h>
#include <stdlib.h>
#include "nnacl/int8/conv_int8.h"
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) {
return PackNCHWToNHWCFp32(src, dst, 1, plane, channel);
@ -80,54 +81,6 @@ void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *pack
} // kernel plane loop
}
void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum) {
// original weight format : ohwi
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int in_channel = conv_param->input_channel_;
int out_channel = conv_param->output_channel_;
int oc4 = UP_DIV(out_channel, C4NUM);
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane;
int unit_size = C4NUM * C4NUM;
int block_size = pack_weight_size / oc4;
QuantArg *filter_args = conv_param->conv_quant_arg_.filter_quant_args_;
for (int m = 0; m < kernel_plane; m++) {
int kernel_plane_stride = m * in_channel;
int packed_kernel_plane_stride = m * unit_size * ic4;
for (int i = 0; i < ic4; i++) {
int channel_block_stride = kernel_plane_stride + i * C4NUM;
int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size;
int ic_remainder = in_channel - i * C4NUM;
int real_ic_num = ic_remainder < C4NUM ? ic_remainder : C4NUM;
for (int h = 0; h < real_ic_num; h++) {
int block_stride = channel_block_stride + h;
int packed_block_stride = packed_channel_block_size + h;
for (int j = 0; j < oc4; j++) {
int kernel_block_stride = block_stride + j * C4NUM * kernel_plane * in_channel;
int packed_kernel_block_size = packed_block_stride + j * block_size;
int oc_remainder = out_channel - j * C4NUM;
int real_oc_num = oc_remainder < C4NUM ? oc_remainder : C4NUM;
for (int k = 0; k < real_oc_num; k++) {
int8_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel;
int8_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k * C4NUM;
*packed_data_ptr = origin_data_ptr[0];
int32_t f_zp;
if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) {
f_zp = filter_args[j * C4NUM + k].zp_;
} else {
f_zp = filter_args[0].zp_;
}
weight_sum[j * C4NUM + k] += (int32_t)(packed_data_ptr[0] - f_zp);
}
} // kernel block loop
} // inchannel block loop
} // channel block loop
} // kernel plane loop
}
void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size) {
/* support nhwc */
char *src = (char *)src_ptr;
@ -391,11 +344,10 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
} // tile num loop
}
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
int32_t *input_sum, ConvParameter *conv_param) {
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param,
bool per_channel) {
// input format : nhwc
int tile_num = conv_param->tile_num_;
QuantArg *filter_arg = conv_param->conv_quant_arg_.filter_quant_args_;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
@ -407,11 +359,8 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
int in_channel = conv_param->input_channel_;
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int ic4_minus = in_channel / C4NUM;
int ic4 = UP_DIV(in_channel, C4NUM);
int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
int out_w = conv_param->output_w_;
int block_size = kernel_h * kernel_w;
int kernel_plane = kernel_h * kernel_w;
for (int i = 0; i < real_cal_num; i++) {
int block_start = block_index + i;
@ -422,47 +371,30 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
if (dilation_w == 1 && dilation_h == 1) {
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * in_w * in_channel + input_stride;
int input_x_stride = input_y_stride + kw_s * in_channel;
int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, (kw_e - kw_s) * in_channel);
} // kernel_h loop
} else {
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
for (int n = kw_s; n < kw_e; n++) {
int input_x_stride = input_y_stride + n * dilation_w * in_channel;
int input_plane_offset = (j * kernel_w + n) * tile_num * C4NUM * ic4 + i * C4NUM;
for (int m = 0; m < ic4_minus; m++) {
int channel_block_stride = input_x_stride + m * C4NUM;
int channel_block_offset = input_plane_offset + m * tile_num * C4NUM;
memcpy(packed_input + channel_block_offset, input_data + channel_block_stride, 4);
} // channel_block loop
int ic_res = conv_param->input_channel_ - ic4_minus * C4NUM;
for (int l = 0; l < ic_res; ++l) {
int channel_block_stride = input_x_stride + ic4_minus * C4NUM + l;
int channel_block_offset = input_plane_offset + ic4_minus * tile_num * C4NUM + l;
packed_input[channel_block_offset] = input_data[channel_block_stride];
for (int k = kw_s; k < kw_e; ++k) {
int input_x_stride = input_y_stride + k * dilation_w * in_channel;
int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, in_channel);
}
} // kernel_w loop
} // kernel_h loop
int32_t input_accumulator = 0;
for (int j = 0; j < block_size; j++) {
int block_offset = j * tile_num * ic4 * C4NUM + i * C4NUM;
for (int c = 0; c < ic4; c++) {
int ic4_offset = block_offset + c * tile_num * C4NUM;
for (int k = 0; k < C4NUM; ++k) {
input_accumulator += (packed_input + ic4_offset)[k];
}
}
}
if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC)) {
continue;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int cal_num_offset = i * oc4 * C4NUM;
for (int l = 0; l < conv_param->output_channel_; ++l) {
input_sum[cal_num_offset + l] = input_accumulator * filter_arg[l].zp_;
}
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
input_sum[i] = input_accumulator * filter_arg[0].zp_;
}
} // tile num loop
if (per_channel) {
Conv1x1PreOptPeroc(matmul_input, packed_input, input_sum, kernel_plane * in_channel, conv_param->output_channel_,
real_cal_num, filter_zp, C8NUM * C8NUM);
} else {
Conv1x1PreOptPert(matmul_input, packed_input, input_sum, kernel_plane * in_channel, real_cal_num, conv_param);
}
}
void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param) {

View File

@ -35,8 +35,9 @@ void PackHWCToWHC(const float *src, float *dst, int height, int width, int chann
void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
int32_t *input_sum, ConvParameter *conv_param);
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
int32_t *input_sum, ConvParameter *conv_param);
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param,
bool per_channel);
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);
@ -52,8 +53,6 @@ void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum);
void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum);
void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param);
void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);

View File

@ -32,7 +32,8 @@ using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
void ConvolutionInt8CPUKernel::CheckSupportOptimize() {
tile_num_ = 24;
tile_num_ = 8;
matmul_func_ = MatMulInt8_8x8_r;
#ifdef ENABLE_ARM32
tile_num_ = 2;
support_optimize_ = false;
@ -42,19 +43,19 @@ void ConvolutionInt8CPUKernel::CheckSupportOptimize() {
void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
if (optimize_op_handler != nullptr) {
dlerror();
*(reinterpret_cast<void **>(&gemm_func_)) = dlsym(optimize_op_handler, "IndirectGemmInt8_optimize_handler");
*(reinterpret_cast<void **>(&matmul_func_)) = dlsym(optimize_op_handler, "MatMulRInt8_optimize_handler");
auto dlopen_error = dlerror();
if (dlopen_error != nullptr) {
MS_LOG(ERROR) << "load gemm func failed! " << dlopen_error << ".";
tile_num_ = 4;
MS_LOG(ERROR) << "load matmul func failed! " << dlopen_error << ".";
support_optimize_ = false;
gemm_func_ = nullptr;
matmul_func_ = nullptr;
} else {
// do nothing
support_optimize_ = true;
}
} else {
tile_num_ = 4;
support_optimize_ = false;
matmul_func_ = nullptr;
}
#endif
conv_param_->tile_num_ = tile_num_;
@ -141,7 +142,6 @@ int ConvolutionInt8CPUKernel::InitWeightBias() {
int ConvolutionInt8CPUKernel::InitTmpBuffer() {
MS_ASSERT(ctx_->allocator != nullptr);
int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_;
int plane_c4 = UP_DIV(kernel_plane, C4NUM);
@ -176,11 +176,10 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
int kernel_w = filter_tensor->Width();
conv_param_->input_channel_ = input_channel;
conv_param_->output_channel_ = output_channel;
int ic4 = UP_DIV(input_channel, C4NUM);
int oc4 = UP_DIV(output_channel, C4NUM);
int oc8 = UP_DIV(output_channel, C8NUM);
int kernel_plane = kernel_h * kernel_w;
int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane;
auto filter_arg = conv_param_->conv_quant_arg_.filter_quant_args_;
int up_round_deep = UP_ROUND(kernel_plane * input_channel, C4NUM);
int pack_weight_size = oc8 * C8NUM * up_round_deep;
int32_t input_zp = conv_param_->conv_quant_arg_.input_quant_args_[0].zp_;
// init weight
@ -191,27 +190,15 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
return RET_ERROR;
}
memset(packed_weight_, 0, pack_weight_size);
auto *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * output_channel));
if (weight_sum == nullptr) {
MS_LOG(ERROR) << "malloc weight_sum failed.";
return RET_ERROR;
}
for (int i = 0; i < output_channel; i++) {
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
weight_sum[i] = ic4 * C4NUM * kernel_plane * filter_arg[i].zp_;
} else {
weight_sum[i] = ic4 * C4NUM * kernel_plane * filter_arg[0].zp_;
}
}
PackWeightInt8Opt(origin_weight, conv_param_, packed_weight_, weight_sum);
RowMajor2Row8x4MajorInt8(origin_weight, packed_weight_, output_channel, input_channel * kernel_plane);
// init bias
bias_data_ = reinterpret_cast<int32_t *>(malloc(oc4 * C4NUM * sizeof(int32_t)));
bias_data_ = reinterpret_cast<int32_t *>(malloc(oc8 * C8NUM * sizeof(int32_t)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_data_ failed.";
return RET_ERROR;
}
memset(bias_data_, 0, oc4 * C4NUM * sizeof(int32_t));
memset(bias_data_, 0, oc8 * C8NUM * sizeof(int32_t));
if (in_tensors_.size() == kInputSize2) {
auto ori_bias = reinterpret_cast<int32_t *>(in_tensors_.at(kBiasIndex)->MutableData());
memcpy(bias_data_, ori_bias, output_channel * sizeof(int32_t));
@ -219,21 +206,26 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
MS_ASSERT(in_tensors_.size() == kInputSize1);
}
auto *bias_data = reinterpret_cast<int32_t *>(bias_data_);
int c4_kernel_plane_size = kernel_plane * ic4 * C4NUM;
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
for (int i = 0; i < output_channel; i++) {
bias_data[i] += filter_arg[i].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
bool filter_peroc = conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL;
if (filter_peroc) {
filter_zp_ptr_ = reinterpret_cast<int32_t *>(malloc(output_channel * sizeof(int32_t)));
}
} else {
for (int i = 0; i < output_channel; i++) {
bias_data[i] += filter_arg[0].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
for (int oc = 0; oc < output_channel; oc++) {
int32_t filter_zp = conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_;
if (filter_peroc) {
filter_zp = conv_param_->conv_quant_arg_.filter_quant_args_[oc].zp_;
filter_zp_ptr_[oc] = filter_zp;
}
int32_t weight_sum_value = up_round_deep * filter_zp;
for (int i = 0; i < kernel_plane * input_channel; i++) {
weight_sum_value += origin_weight[oc * kernel_plane * input_channel + i] - filter_zp;
}
bias_data[oc] += filter_zp * input_zp * up_round_deep - weight_sum_value * input_zp;
}
free(weight_sum);
size_t input_sum_size;
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
input_sum_size = oc4 * C4NUM * tile_num_ * thread_count_ * sizeof(int32_t);
input_sum_size = oc8 * C8NUM * tile_num_ * thread_count_ * sizeof(int32_t);
} else {
input_sum_size = tile_num_ * thread_count_ * sizeof(int32_t);
}
@ -248,26 +240,17 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
int ConvolutionInt8CPUKernel::InitTmpBufferOpt() {
MS_ASSERT(ctx_->allocator != nullptr);
int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
size_t nhwc4_input_size = ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_;
nhwc4_input_ = ctx_->allocator->Malloc(nhwc4_input_size);
if (nhwc4_input_ == nullptr) {
MS_LOG(ERROR) << "malloc nhwc4 input failed.";
int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_;
int tmp_unit = UP_ROUND(kernel_plane * conv_param_->input_channel_, C4NUM);
matmul_packed_input_ = reinterpret_cast<int8_t *>(
ctx_->allocator->Malloc(thread_count_ * tile_num_ * kernel_plane * conv_param_->input_channel_));
if (matmul_packed_input_ == nullptr) {
MS_LOG(ERROR) << "malloc matmul_packed_input_ failed.";
return RET_ERROR;
}
size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t);
tmp_dst_ = reinterpret_cast<int32_t *>(ctx_->allocator->Malloc(tmp_dst_size));
if (tmp_dst_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_dst_ failed.";
return RET_ERROR;
}
tmp_out_ =
reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(thread_count_ * tile_num_ * conv_param_->output_channel_));
if (tmp_out_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_out_ failed.";
packed_input_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(tmp_unit * thread_count_ * tile_num_));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "malloc packed_input_ failed.";
return RET_ERROR;
}
return RET_OK;
@ -321,8 +304,9 @@ int ConvolutionInt8CPUKernel::RunImpl(int task_id) {
auto ori_input_data = reinterpret_cast<int8_t *>(input_tensor->MutableData());
auto output_addr = reinterpret_cast<int8_t *>(out_tensors_.at(kOutputIndex)->MutableData());
if (support_optimize_) {
ConvInt8Opt(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), tmp_dst_,
tmp_out_, output_addr, input_sum_, task_id, conv_param_, gemm_func_);
ConvInt8Opt(ori_input_data, packed_input_, matmul_packed_input_, packed_weight_,
reinterpret_cast<int32_t *>(bias_data_), output_addr, filter_zp_ptr_, input_sum_, task_id, conv_param_,
matmul_func_);
} else {
ConvInt8(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), tmp_dst_, tmp_out_,
output_addr, input_sum_, task_id, conv_param_);
@ -346,12 +330,20 @@ int ConvolutionInt8CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
// init tmp input, output
if (support_optimize_) {
ret = InitTmpBufferOpt();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed.";
return RET_ERROR;
}
} else {
ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed.";
return RET_ERROR;
}
}
int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionInt8Impl, this, thread_count_);
if (error_code != RET_OK) {

View File

@ -40,6 +40,10 @@ class ConvolutionInt8CPUKernel : public ConvolutionBaseCPUKernel {
free(input_sum_);
input_sum_ = nullptr;
}
if (filter_zp_ptr_ != nullptr) {
free(filter_zp_ptr_);
filter_zp_ptr_ = nullptr;
}
}
int Init() override;
@ -58,6 +62,10 @@ class ConvolutionInt8CPUKernel : public ConvolutionBaseCPUKernel {
ctx_->allocator->Free(packed_input_);
packed_input_ = nullptr;
}
if (matmul_packed_input_ != nullptr) {
ctx_->allocator->Free(matmul_packed_input_);
matmul_packed_input_ = nullptr;
}
if (tmp_dst_ != nullptr) {
ctx_->allocator->Free(tmp_dst_);
tmp_dst_ = nullptr;
@ -70,10 +78,12 @@ class ConvolutionInt8CPUKernel : public ConvolutionBaseCPUKernel {
bool support_optimize_ = true;
int8_t *packed_weight_ = nullptr;
int8_t *packed_input_ = nullptr;
int8_t *matmul_packed_input_ = nullptr;
int32_t *filter_zp_ptr_ = nullptr; /* per-oc */
int32_t *input_sum_ = nullptr;
int32_t *tmp_dst_ = nullptr;
int8_t *tmp_out_ = nullptr;
GEMM_FUNC gemm_func_ = nullptr;
MATMUL_OPT_R_FUNC matmul_func_ = nullptr;
};
} // namespace mindspore::kernel

View File

@ -483,9 +483,9 @@ function Run_arm64() {
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> adb_run_cmd.txt
adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
if [ $? = 0 ]; then
run_result='arm64: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
run_result='arm64_awq: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
else
run_result='arm64: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
run_result='arm64_awq: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
fi
done < ${models_tflite_awaretraining_config}