forked from mindspore-Ecosystem/mindspore
improve code nnacl/int8
This commit is contained in:
parent
2d8e44f3d0
commit
a9ae754a19
|
@ -452,7 +452,8 @@ void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weigh
|
||||||
}
|
}
|
||||||
|
|
||||||
void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound,
|
void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound,
|
||||||
bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) {
|
bool w_not_bound, int output_w, int real_num, int oc_start,
|
||||||
|
const ConvParameter *conv_param) {
|
||||||
int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_;
|
int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_;
|
||||||
int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_;
|
int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_;
|
||||||
int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
|
int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
|
||||||
|
@ -745,7 +746,7 @@ void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, in
|
||||||
}
|
}
|
||||||
|
|
||||||
void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index,
|
void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index,
|
||||||
int real_cal_num, int out_w_block, ConvParameter *conv_param) {
|
int real_cal_num, int out_w_block, const ConvParameter *conv_param) {
|
||||||
int output_channel = conv_param->output_channel_;
|
int output_channel = conv_param->output_channel_;
|
||||||
int output_w = conv_param->output_w_;
|
int output_w = conv_param->output_w_;
|
||||||
int output_h = conv_param->output_h_;
|
int output_h = conv_param->output_h_;
|
||||||
|
@ -778,7 +779,7 @@ void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const
|
||||||
}
|
}
|
||||||
|
|
||||||
void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index,
|
void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index,
|
||||||
int real_cal_num, int out_w_block, ConvParameter *conv_param) {
|
int real_cal_num, int out_w_block, const ConvParameter *conv_param) {
|
||||||
// input data format : nhwc
|
// input data format : nhwc
|
||||||
int input_channel = conv_param->input_channel_;
|
int input_channel = conv_param->input_channel_;
|
||||||
int input_width = conv_param->input_w_;
|
int input_width = conv_param->input_w_;
|
||||||
|
@ -868,7 +869,7 @@ void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, in
|
||||||
// int8 convolution 3x3
|
// int8 convolution 3x3
|
||||||
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
|
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
|
||||||
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
|
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
|
||||||
int task_id, ConvParameter *conv_param) {
|
int task_id, const ConvParameter *conv_param) {
|
||||||
int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
|
int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
|
||||||
int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
|
int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
|
||||||
int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
|
int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
|
||||||
|
|
|
@ -39,7 +39,7 @@ void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weigh
|
||||||
|
|
||||||
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
|
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
|
||||||
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
|
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
|
||||||
int task_id, ConvParameter *conv_param);
|
int task_id, const ConvParameter *conv_param);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
|
|
@ -154,8 +154,8 @@ void ConvDw3x3Int8InitBuffer(int8_t *buffer, const int8_t *input, const ConvPara
|
||||||
|
|
||||||
void ConvDw3x3Int8Window(int8_t *output, const int8_t *buffer, const int16_t *weight, const int32_t *bias, int col_size,
|
void ConvDw3x3Int8Window(int8_t *output, const int8_t *buffer, const int16_t *weight, const int32_t *bias, int col_size,
|
||||||
int row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp,
|
int row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp,
|
||||||
int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min,
|
const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift,
|
||||||
int32_t acc_max, int stride, bool per_channel) {
|
int32_t acc_min, int32_t acc_max, int stride, bool per_channel) {
|
||||||
for (int w = 0; w < output_w; w++) {
|
for (int w = 0; w < output_w; w++) {
|
||||||
int tmp_buffer[C8NUM];
|
int tmp_buffer[C8NUM];
|
||||||
for (int i = 0; i < C8NUM; i++) {
|
for (int i = 0; i < C8NUM; i++) {
|
||||||
|
@ -330,8 +330,8 @@ void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data
|
||||||
#ifndef ENABLE_ARM32
|
#ifndef ENABLE_ARM32
|
||||||
void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
|
void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
|
||||||
int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp,
|
int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp,
|
||||||
int *out_multiplier, int *left_shift, int *right_shift, int32_t acc_min, int32_t acc_max,
|
const int *out_multiplier, const int *left_shift, const int *right_shift, int32_t acc_min,
|
||||||
bool per_channel) {
|
int32_t acc_max, bool per_channel) {
|
||||||
for (int c = 0; c < channel; c += 8) {
|
for (int c = 0; c < channel; c += 8) {
|
||||||
int tmp_buffer[8];
|
int tmp_buffer[8];
|
||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
|
@ -563,8 +563,9 @@ void ConvDwInt8Border(int8_t *dst, const int8_t *src, const int16_t *weight, con
|
||||||
#ifndef ENABLE_ARM
|
#ifndef ENABLE_ARM
|
||||||
void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, int width,
|
void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, int width,
|
||||||
int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step,
|
int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step,
|
||||||
int in_kh_step, int in_kw_step, int8_t *in_zp, int32_t *out_zp, int32_t *out_multiplier,
|
int in_kh_step, int in_kw_step, const int8_t *in_zp, const int32_t *out_zp,
|
||||||
int32_t *left_shift, int32_t *right_shift, int32_t *acc_min, int32_t *acc_max) {
|
const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift,
|
||||||
|
const int32_t *acc_min, const int32_t *acc_max) {
|
||||||
int tmp_buffer[C8NUM];
|
int tmp_buffer[C8NUM];
|
||||||
int8_t *dst_h = dst;
|
int8_t *dst_h = dst;
|
||||||
const int8_t *src_h = src;
|
const int8_t *src_h = src;
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include "nnacl/int8/matmul_int8.h"
|
#include "nnacl/int8/matmul_int8.h"
|
||||||
#include "nnacl/int8/common_func_int8.h"
|
#include "nnacl/int8/common_func_int8.h"
|
||||||
int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel,
|
int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel,
|
||||||
ConvParameter *conv_param) {
|
const ConvParameter *conv_param) {
|
||||||
/* row4x4-major(ih*iw x oc*kh*kw) -> row4-major(oh*ow x oc) */
|
/* row4x4-major(ih*iw x oc*kh*kw) -> row4-major(oh*ow x oc) */
|
||||||
size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
|
size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
|
||||||
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
|
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
|
||||||
|
@ -107,8 +107,8 @@ void DeConvWeightTransInt8(const int8_t *src, int8_t *dst, int input_channel, in
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep, int col4,
|
void DeConvPackWeightSum(const int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep,
|
||||||
bool suppport_opt) {
|
int col4, bool suppport_opt) {
|
||||||
int deep16 = UP_ROUND(deep, C16NUM);
|
int deep16 = UP_ROUND(deep, C16NUM);
|
||||||
for (int c = 0; c < col4; c++) {
|
for (int c = 0; c < col4; c++) {
|
||||||
int c4div = c / C4NUM, c4mod = c % C4NUM;
|
int c4div = c / C4NUM, c4mod = c % C4NUM;
|
||||||
|
|
|
@ -27,8 +27,8 @@
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep16, int col4,
|
void DeConvPackWeightSum(const int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep16,
|
||||||
bool suppport_opt);
|
int col4, bool suppport_opt);
|
||||||
void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16,
|
void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16,
|
||||||
bool suppport_opt);
|
bool suppport_opt);
|
||||||
void DeConvWeightTransInt8(const int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane,
|
void DeConvWeightTransInt8(const int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane,
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
#include "nnacl/int8/layer_norm_int8.h"
|
#include "nnacl/int8/layer_norm_int8.h"
|
||||||
|
|
||||||
void LayerNormGammaAndBetaInt8(int8_t *dst, const int8_t *src, const float *gamma_data, const float *beta_data,
|
void LayerNormGammaAndBetaInt8(int8_t *dst, const int8_t *src, const float *gamma_data, const float *beta_data,
|
||||||
LayerNormQuantArg *quant, int num, const float mean, const float deno) {
|
const LayerNormQuantArg *quant, int num, const float mean, const float deno) {
|
||||||
for (int i = 0; i < num; i++) {
|
for (int i = 0; i < num; i++) {
|
||||||
float fp32_src = (src[i] - quant->in_zp_) * quant->in_scale_;
|
float fp32_src = (src[i] - quant->in_zp_) * quant->in_scale_;
|
||||||
float fp32_dst = (fp32_src - mean) * deno;
|
float fp32_dst = (fp32_src - mean) * deno;
|
||||||
|
@ -33,7 +33,7 @@ void LayerNormGammaAndBetaInt8(int8_t *dst, const int8_t *src, const float *gamm
|
||||||
*
|
*
|
||||||
* */
|
* */
|
||||||
int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data,
|
int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data,
|
||||||
LayerNormParameter *param, LayerNormQuantArg *quant, int task_id) {
|
const LayerNormParameter *param, const LayerNormQuantArg *quant, int task_id) {
|
||||||
if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) {
|
if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) {
|
||||||
return NNACL_NULL_PTR;
|
return NNACL_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data,
|
int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data,
|
||||||
LayerNormParameter *param, LayerNormQuantArg *quant, int task_id);
|
const LayerNormParameter *param, const LayerNormQuantArg *quant, int task_id);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,7 +92,7 @@ void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
|
void RowMajor2Row16x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
|
||||||
/* Row-major to row16x4-major (block row-major) */
|
/* Row-major to row16x4-major (block row-major) */
|
||||||
int col16 = UP_ROUND(col, C16NUM);
|
int col16 = UP_ROUND(col, C16NUM);
|
||||||
size_t row_4div = row / C4NUM * C4NUM;
|
size_t row_4div = row / C4NUM * C4NUM;
|
||||||
|
@ -231,8 +231,9 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
|
||||||
|
|
||||||
#ifndef ENABLE_ARM
|
#ifndef ENABLE_ARM
|
||||||
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
|
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
|
||||||
const int *bias, int mini, int maxi, int out_zp, int32_t *multiplier, int32_t *left_shift,
|
const int *bias, int mini, int maxi, int out_zp, const int32_t *multiplier,
|
||||||
int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp) {
|
const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc,
|
||||||
|
const int32_t *filter_zp) {
|
||||||
/*
|
/*
|
||||||
* row4x16-major * row16x4-major => (int8)row-major
|
* row4x16-major * row16x4-major => (int8)row-major
|
||||||
* support per-layer && weight per-channel
|
* support per-layer && weight per-channel
|
||||||
|
|
|
@ -28,14 +28,15 @@ extern "C" {
|
||||||
/* matmul */
|
/* matmul */
|
||||||
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
|
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
|
||||||
const int *input_sum, const int *bias);
|
const int *input_sum, const int *bias);
|
||||||
void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
|
void RowMajor2Row16x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
|
||||||
void RowMajor2Col16x4MajorInt8(int8_t *src, int row, int col, int8_t *dst);
|
void RowMajor2Col16x4MajorInt8(int8_t *src, int row, int col, int8_t *dst);
|
||||||
void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
|
void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
|
||||||
void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int *weight_zp_ptr, const int *bias, int *dst,
|
void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int *weight_zp_ptr, const int *bias, int *dst,
|
||||||
DataOrder order, bool filter_per_channel);
|
DataOrder order, bool filter_per_channel);
|
||||||
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
|
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
|
||||||
const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
|
const int *bias, int act_min, int act_max, int out_zp, const int32_t *multiplier,
|
||||||
int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp);
|
const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc,
|
||||||
|
const int32_t *filter_zp);
|
||||||
|
|
||||||
/* 8x4 4x8 -> 8x8 */
|
/* 8x4 4x8 -> 8x8 */
|
||||||
/* optimize conv */
|
/* optimize conv */
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
#include "nnacl/int8/power_int8.h"
|
#include "nnacl/int8/power_int8.h"
|
||||||
|
|
||||||
int PowerInt8(const int8_t *input, int8_t *exp_ptr, int8_t *output, int count, PowerParameter *param) {
|
int PowerInt8(const int8_t *input, const int8_t *exp_ptr, int8_t *output, int count, PowerParameter *param) {
|
||||||
double input_scale = param->quant_arg_.in_args_.scale_;
|
double input_scale = param->quant_arg_.in_args_.scale_;
|
||||||
int input_zp = param->quant_arg_.in_args_.zp_;
|
int input_zp = param->quant_arg_.in_args_.zp_;
|
||||||
double output_scale = param->quant_arg_.out_args_.scale_;
|
double output_scale = param->quant_arg_.out_args_.scale_;
|
||||||
|
|
|
@ -24,7 +24,7 @@
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
int PowerInt8(const int8_t *input_ptr, int8_t *exp_ptr, int8_t *output_ptr, int count, PowerParameter *parameter);
|
int PowerInt8(const int8_t *input_ptr, const int8_t *exp_ptr, int8_t *output_ptr, int count, PowerParameter *parameter);
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -17,7 +17,8 @@
|
||||||
#include "nnacl/int8/unsqueeze_int8.h"
|
#include "nnacl/int8/unsqueeze_int8.h"
|
||||||
#include "nnacl/unsqueeze_parameter.h"
|
#include "nnacl/unsqueeze_parameter.h"
|
||||||
|
|
||||||
int Int8Unsqueeze(int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size, int task_id) {
|
int Int8Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size,
|
||||||
|
int task_id) {
|
||||||
float output_scale = para_->quant_arg.out_quant_args_.scale_;
|
float output_scale = para_->quant_arg.out_quant_args_.scale_;
|
||||||
int8_t output_zp = para_->quant_arg.out_quant_args_.zp_;
|
int8_t output_zp = para_->quant_arg.out_quant_args_.zp_;
|
||||||
float input_scale = para_->quant_arg.in_quant_args_.scale_;
|
float input_scale = para_->quant_arg.in_quant_args_.scale_;
|
||||||
|
|
|
@ -24,7 +24,8 @@
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
int Int8Unsqueeze(int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size, int task_id);
|
int Int8Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size,
|
||||||
|
int task_id);
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
Loading…
Reference in New Issue