improve code nnacl/int8

This commit is contained in:
xuanyue 2021-06-21 20:17:24 +08:00
parent 2d8e44f3d0
commit a9ae754a19
13 changed files with 35 additions and 29 deletions

View File

@ -452,7 +452,8 @@ void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weigh
} }
void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound,
bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) { bool w_not_bound, int output_w, int real_num, int oc_start,
const ConvParameter *conv_param) {
int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_;
int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_;
int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
@ -745,7 +746,7 @@ void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, in
} }
void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param) { int real_cal_num, int out_w_block, const ConvParameter *conv_param) {
int output_channel = conv_param->output_channel_; int output_channel = conv_param->output_channel_;
int output_w = conv_param->output_w_; int output_w = conv_param->output_w_;
int output_h = conv_param->output_h_; int output_h = conv_param->output_h_;
@ -778,7 +779,7 @@ void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const
} }
void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param) { int real_cal_num, int out_w_block, const ConvParameter *conv_param) {
// input data format : nhwc // input data format : nhwc
int input_channel = conv_param->input_channel_; int input_channel = conv_param->input_channel_;
int input_width = conv_param->input_w_; int input_width = conv_param->input_w_;
@ -868,7 +869,7 @@ void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, in
// int8 convolution 3x3 // int8 convolution 3x3
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
int task_id, ConvParameter *conv_param) { int task_id, const ConvParameter *conv_param) {
int ic8 = UP_DIV(conv_param->input_channel_, C8NUM); int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT); int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);

View File

@ -39,7 +39,7 @@ void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weigh
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
int task_id, ConvParameter *conv_param); int task_id, const ConvParameter *conv_param);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -154,8 +154,8 @@ void ConvDw3x3Int8InitBuffer(int8_t *buffer, const int8_t *input, const ConvPara
void ConvDw3x3Int8Window(int8_t *output, const int8_t *buffer, const int16_t *weight, const int32_t *bias, int col_size, void ConvDw3x3Int8Window(int8_t *output, const int8_t *buffer, const int16_t *weight, const int32_t *bias, int col_size,
int row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp, int row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp,
int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min, const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift,
int32_t acc_max, int stride, bool per_channel) { int32_t acc_min, int32_t acc_max, int stride, bool per_channel) {
for (int w = 0; w < output_w; w++) { for (int w = 0; w < output_w; w++) {
int tmp_buffer[C8NUM]; int tmp_buffer[C8NUM];
for (int i = 0; i < C8NUM; i++) { for (int i = 0; i < C8NUM; i++) {
@ -330,8 +330,8 @@ void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data
#ifndef ENABLE_ARM32 #ifndef ENABLE_ARM32
void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp,
int *out_multiplier, int *left_shift, int *right_shift, int32_t acc_min, int32_t acc_max, const int *out_multiplier, const int *left_shift, const int *right_shift, int32_t acc_min,
bool per_channel) { int32_t acc_max, bool per_channel) {
for (int c = 0; c < channel; c += 8) { for (int c = 0; c < channel; c += 8) {
int tmp_buffer[8]; int tmp_buffer[8];
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
@ -563,8 +563,9 @@ void ConvDwInt8Border(int8_t *dst, const int8_t *src, const int16_t *weight, con
#ifndef ENABLE_ARM #ifndef ENABLE_ARM
void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, int width, void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, int width,
int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step,
int in_kh_step, int in_kw_step, int8_t *in_zp, int32_t *out_zp, int32_t *out_multiplier, int in_kh_step, int in_kw_step, const int8_t *in_zp, const int32_t *out_zp,
int32_t *left_shift, int32_t *right_shift, int32_t *acc_min, int32_t *acc_max) { const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift,
const int32_t *acc_min, const int32_t *acc_max) {
int tmp_buffer[C8NUM]; int tmp_buffer[C8NUM];
int8_t *dst_h = dst; int8_t *dst_h = dst;
const int8_t *src_h = src; const int8_t *src_h = src;

View File

@ -18,7 +18,7 @@
#include "nnacl/int8/matmul_int8.h" #include "nnacl/int8/matmul_int8.h"
#include "nnacl/int8/common_func_int8.h" #include "nnacl/int8/common_func_int8.h"
int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel,
ConvParameter *conv_param) { const ConvParameter *conv_param) {
/* row4x4-major(ih*iw x oc*kh*kw) -> row4-major(oh*ow x oc) */ /* row4x4-major(ih*iw x oc*kh*kw) -> row4-major(oh*ow x oc) */
size_t input_plane = conv_param->input_w_ * conv_param->input_h_; size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
@ -107,8 +107,8 @@ void DeConvWeightTransInt8(const int8_t *src, int8_t *dst, int input_channel, in
return; return;
} }
void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep, int col4, void DeConvPackWeightSum(const int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep,
bool suppport_opt) { int col4, bool suppport_opt) {
int deep16 = UP_ROUND(deep, C16NUM); int deep16 = UP_ROUND(deep, C16NUM);
for (int c = 0; c < col4; c++) { for (int c = 0; c < col4; c++) {
int c4div = c / C4NUM, c4mod = c % C4NUM; int c4div = c / C4NUM, c4mod = c % C4NUM;

View File

@ -27,8 +27,8 @@
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep16, int col4, void DeConvPackWeightSum(const int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep16,
bool suppport_opt); int col4, bool suppport_opt);
void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16, void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16,
bool suppport_opt); bool suppport_opt);
void DeConvWeightTransInt8(const int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane, void DeConvWeightTransInt8(const int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane,

View File

@ -17,7 +17,7 @@
#include "nnacl/int8/layer_norm_int8.h" #include "nnacl/int8/layer_norm_int8.h"
void LayerNormGammaAndBetaInt8(int8_t *dst, const int8_t *src, const float *gamma_data, const float *beta_data, void LayerNormGammaAndBetaInt8(int8_t *dst, const int8_t *src, const float *gamma_data, const float *beta_data,
LayerNormQuantArg *quant, int num, const float mean, const float deno) { const LayerNormQuantArg *quant, int num, const float mean, const float deno) {
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
float fp32_src = (src[i] - quant->in_zp_) * quant->in_scale_; float fp32_src = (src[i] - quant->in_zp_) * quant->in_scale_;
float fp32_dst = (fp32_src - mean) * deno; float fp32_dst = (fp32_src - mean) * deno;
@ -33,7 +33,7 @@ void LayerNormGammaAndBetaInt8(int8_t *dst, const int8_t *src, const float *gamm
* *
* */ * */
int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data, int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data,
LayerNormParameter *param, LayerNormQuantArg *quant, int task_id) { const LayerNormParameter *param, const LayerNormQuantArg *quant, int task_id) {
if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) { if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) {
return NNACL_NULL_PTR; return NNACL_NULL_PTR;
} }

View File

@ -25,7 +25,7 @@ extern "C" {
#endif #endif
int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data, int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data,
LayerNormParameter *param, LayerNormQuantArg *quant, int task_id); const LayerNormParameter *param, const LayerNormQuantArg *quant, int task_id);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -92,7 +92,7 @@ void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row,
} }
} }
void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { void RowMajor2Row16x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
/* Row-major to row16x4-major (block row-major) */ /* Row-major to row16x4-major (block row-major) */
int col16 = UP_ROUND(col, C16NUM); int col16 = UP_ROUND(col, C16NUM);
size_t row_4div = row / C4NUM * C4NUM; size_t row_4div = row / C4NUM * C4NUM;
@ -231,8 +231,9 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
#ifndef ENABLE_ARM #ifndef ENABLE_ARM
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums, void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
const int *bias, int mini, int maxi, int out_zp, int32_t *multiplier, int32_t *left_shift, const int *bias, int mini, int maxi, int out_zp, const int32_t *multiplier,
int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp) { const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc,
const int32_t *filter_zp) {
/* /*
* row4x16-major * row16x4-major => (int8)row-major * row4x16-major * row16x4-major => (int8)row-major
* support per-layer && weight per-channel * support per-layer && weight per-channel

View File

@ -28,14 +28,15 @@ extern "C" {
/* matmul */ /* matmul */
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias); const int *input_sum, const int *bias);
void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); void RowMajor2Row16x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Col16x4MajorInt8(int8_t *src, int row, int col, int8_t *dst); void RowMajor2Col16x4MajorInt8(int8_t *src, int row, int col, int8_t *dst);
void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order); void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int *weight_zp_ptr, const int *bias, int *dst, void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int *weight_zp_ptr, const int *bias, int *dst,
DataOrder order, bool filter_per_channel); DataOrder order, bool filter_per_channel);
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums, void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift, const int *bias, int act_min, int act_max, int out_zp, const int32_t *multiplier,
int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp); const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc,
const int32_t *filter_zp);
/* 8x4 4x8 -> 8x8 */ /* 8x4 4x8 -> 8x8 */
/* optimize conv */ /* optimize conv */

View File

@ -16,7 +16,7 @@
#include "nnacl/int8/power_int8.h" #include "nnacl/int8/power_int8.h"
int PowerInt8(const int8_t *input, int8_t *exp_ptr, int8_t *output, int count, PowerParameter *param) { int PowerInt8(const int8_t *input, const int8_t *exp_ptr, int8_t *output, int count, PowerParameter *param) {
double input_scale = param->quant_arg_.in_args_.scale_; double input_scale = param->quant_arg_.in_args_.scale_;
int input_zp = param->quant_arg_.in_args_.zp_; int input_zp = param->quant_arg_.in_args_.zp_;
double output_scale = param->quant_arg_.out_args_.scale_; double output_scale = param->quant_arg_.out_args_.scale_;

View File

@ -24,7 +24,7 @@
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
int PowerInt8(const int8_t *input_ptr, int8_t *exp_ptr, int8_t *output_ptr, int count, PowerParameter *parameter); int PowerInt8(const int8_t *input_ptr, const int8_t *exp_ptr, int8_t *output_ptr, int count, PowerParameter *parameter);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -17,7 +17,8 @@
#include "nnacl/int8/unsqueeze_int8.h" #include "nnacl/int8/unsqueeze_int8.h"
#include "nnacl/unsqueeze_parameter.h" #include "nnacl/unsqueeze_parameter.h"
int Int8Unsqueeze(int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size, int task_id) { int Int8Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size,
int task_id) {
float output_scale = para_->quant_arg.out_quant_args_.scale_; float output_scale = para_->quant_arg.out_quant_args_.scale_;
int8_t output_zp = para_->quant_arg.out_quant_args_.zp_; int8_t output_zp = para_->quant_arg.out_quant_args_.zp_;
float input_scale = para_->quant_arg.in_quant_args_.scale_; float input_scale = para_->quant_arg.in_quant_args_.scale_;

View File

@ -24,7 +24,8 @@
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
int Int8Unsqueeze(int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size, int task_id); int Int8Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size,
int task_id);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif