improve code nnacl/int8

This commit is contained in:
xuanyue 2021-06-21 20:17:24 +08:00
parent 2d8e44f3d0
commit a9ae754a19
13 changed files with 35 additions and 29 deletions

View File

@ -452,7 +452,8 @@ void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weigh
}
void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound,
bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) {
bool w_not_bound, int output_w, int real_num, int oc_start,
const ConvParameter *conv_param) {
int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_;
int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_;
int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
@ -745,7 +746,7 @@ void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, in
}
void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param) {
int real_cal_num, int out_w_block, const ConvParameter *conv_param) {
int output_channel = conv_param->output_channel_;
int output_w = conv_param->output_w_;
int output_h = conv_param->output_h_;
@ -778,7 +779,7 @@ void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const
}
void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param) {
int real_cal_num, int out_w_block, const ConvParameter *conv_param) {
// input data format : nhwc
int input_channel = conv_param->input_channel_;
int input_width = conv_param->input_w_;
@ -868,7 +869,7 @@ void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, in
// int8 convolution 3x3
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
int task_id, ConvParameter *conv_param) {
int task_id, const ConvParameter *conv_param) {
int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);

View File

@ -39,7 +39,7 @@ void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weigh
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
int task_id, ConvParameter *conv_param);
int task_id, const ConvParameter *conv_param);
#ifdef __cplusplus
}

View File

@ -154,8 +154,8 @@ void ConvDw3x3Int8InitBuffer(int8_t *buffer, const int8_t *input, const ConvPara
void ConvDw3x3Int8Window(int8_t *output, const int8_t *buffer, const int16_t *weight, const int32_t *bias, int col_size,
int row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp,
int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min,
int32_t acc_max, int stride, bool per_channel) {
const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift,
int32_t acc_min, int32_t acc_max, int stride, bool per_channel) {
for (int w = 0; w < output_w; w++) {
int tmp_buffer[C8NUM];
for (int i = 0; i < C8NUM; i++) {
@ -330,8 +330,8 @@ void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data
#ifndef ENABLE_ARM32
void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp,
int *out_multiplier, int *left_shift, int *right_shift, int32_t acc_min, int32_t acc_max,
bool per_channel) {
const int *out_multiplier, const int *left_shift, const int *right_shift, int32_t acc_min,
int32_t acc_max, bool per_channel) {
for (int c = 0; c < channel; c += 8) {
int tmp_buffer[8];
for (int i = 0; i < 8; i++) {
@ -563,8 +563,9 @@ void ConvDwInt8Border(int8_t *dst, const int8_t *src, const int16_t *weight, con
#ifndef ENABLE_ARM
void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, int width,
int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step,
int in_kh_step, int in_kw_step, int8_t *in_zp, int32_t *out_zp, int32_t *out_multiplier,
int32_t *left_shift, int32_t *right_shift, int32_t *acc_min, int32_t *acc_max) {
int in_kh_step, int in_kw_step, const int8_t *in_zp, const int32_t *out_zp,
const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift,
const int32_t *acc_min, const int32_t *acc_max) {
int tmp_buffer[C8NUM];
int8_t *dst_h = dst;
const int8_t *src_h = src;

View File

@ -18,7 +18,7 @@
#include "nnacl/int8/matmul_int8.h"
#include "nnacl/int8/common_func_int8.h"
int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel,
ConvParameter *conv_param) {
const ConvParameter *conv_param) {
/* row4x4-major(ih*iw x oc*kh*kw) -> row4-major(oh*ow x oc) */
size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
@ -107,8 +107,8 @@ void DeConvWeightTransInt8(const int8_t *src, int8_t *dst, int input_channel, in
return;
}
void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep, int col4,
bool suppport_opt) {
void DeConvPackWeightSum(const int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep,
int col4, bool suppport_opt) {
int deep16 = UP_ROUND(deep, C16NUM);
for (int c = 0; c < col4; c++) {
int c4div = c / C4NUM, c4mod = c % C4NUM;

View File

@ -27,8 +27,8 @@
#ifdef __cplusplus
extern "C" {
#endif
void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep16, int col4,
bool suppport_opt);
void DeConvPackWeightSum(const int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep16,
int col4, bool suppport_opt);
void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16,
bool suppport_opt);
void DeConvWeightTransInt8(const int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane,

View File

@ -17,7 +17,7 @@
#include "nnacl/int8/layer_norm_int8.h"
void LayerNormGammaAndBetaInt8(int8_t *dst, const int8_t *src, const float *gamma_data, const float *beta_data,
LayerNormQuantArg *quant, int num, const float mean, const float deno) {
const LayerNormQuantArg *quant, int num, const float mean, const float deno) {
for (int i = 0; i < num; i++) {
float fp32_src = (src[i] - quant->in_zp_) * quant->in_scale_;
float fp32_dst = (fp32_src - mean) * deno;
@ -33,7 +33,7 @@ void LayerNormGammaAndBetaInt8(int8_t *dst, const int8_t *src, const float *gamm
*
* */
int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data,
LayerNormParameter *param, LayerNormQuantArg *quant, int task_id) {
const LayerNormParameter *param, const LayerNormQuantArg *quant, int task_id) {
if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) {
return NNACL_NULL_PTR;
}

View File

@ -25,7 +25,7 @@ extern "C" {
#endif
int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data,
LayerNormParameter *param, LayerNormQuantArg *quant, int task_id);
const LayerNormParameter *param, const LayerNormQuantArg *quant, int task_id);
#ifdef __cplusplus
}

View File

@ -92,7 +92,7 @@ void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row,
}
}
void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
void RowMajor2Row16x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
/* Row-major to row16x4-major (block row-major) */
int col16 = UP_ROUND(col, C16NUM);
size_t row_4div = row / C4NUM * C4NUM;
@ -231,8 +231,9 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
#ifndef ENABLE_ARM
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
const int *bias, int mini, int maxi, int out_zp, int32_t *multiplier, int32_t *left_shift,
int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp) {
const int *bias, int mini, int maxi, int out_zp, const int32_t *multiplier,
const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc,
const int32_t *filter_zp) {
/*
* row4x16-major * row16x4-major => (int8)row-major
* support per-layer && weight per-channel

View File

@ -28,14 +28,15 @@ extern "C" {
/* matmul */
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias);
void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Row16x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Col16x4MajorInt8(int8_t *src, int row, int col, int8_t *dst);
void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int *weight_zp_ptr, const int *bias, int *dst,
DataOrder order, bool filter_per_channel);
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp);
const int *bias, int act_min, int act_max, int out_zp, const int32_t *multiplier,
const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc,
const int32_t *filter_zp);
/* 8x4 4x8 -> 8x8 */
/* optimize conv */

View File

@ -16,7 +16,7 @@
#include "nnacl/int8/power_int8.h"
int PowerInt8(const int8_t *input, int8_t *exp_ptr, int8_t *output, int count, PowerParameter *param) {
int PowerInt8(const int8_t *input, const int8_t *exp_ptr, int8_t *output, int count, PowerParameter *param) {
double input_scale = param->quant_arg_.in_args_.scale_;
int input_zp = param->quant_arg_.in_args_.zp_;
double output_scale = param->quant_arg_.out_args_.scale_;

View File

@ -24,7 +24,7 @@
#ifdef __cplusplus
extern "C" {
#endif
int PowerInt8(const int8_t *input_ptr, int8_t *exp_ptr, int8_t *output_ptr, int count, PowerParameter *parameter);
int PowerInt8(const int8_t *input_ptr, const int8_t *exp_ptr, int8_t *output_ptr, int count, PowerParameter *parameter);
#ifdef __cplusplus
}
#endif

View File

@ -17,7 +17,8 @@
#include "nnacl/int8/unsqueeze_int8.h"
#include "nnacl/unsqueeze_parameter.h"
int Int8Unsqueeze(int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size, int task_id) {
int Int8Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size,
int task_id) {
float output_scale = para_->quant_arg.out_quant_args_.scale_;
int8_t output_zp = para_->quant_arg.out_quant_args_.zp_;
float input_scale = para_->quant_arg.in_quant_args_.scale_;

View File

@ -24,7 +24,8 @@
#ifdef __cplusplus
extern "C" {
#endif
int Int8Unsqueeze(int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size, int task_id);
int Int8Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size,
int task_id);
#ifdef __cplusplus
}
#endif