Add int8 conv weight per channel

This commit is contained in:
fuzhiye 2020-08-14 10:10:42 +08:00
parent c7b50bcdd2
commit 72953307d9
10 changed files with 573 additions and 174 deletions

View File

@ -15,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include <float.h>
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "include/errorcode.h"
@ -66,13 +67,14 @@ void ConvolutionBaseCPUKernel::FreeQuantParam() {
free(conv_quant_arg_->out_act_max_);
conv_quant_arg_->out_act_max_ = nullptr;
}
if (conv_quant_arg_->quant_args_ != nullptr) {
for (int i = 0; i < 3; ++i) {
if (*(conv_quant_arg_->quant_args_ + i) != nullptr) {
free(*(conv_quant_arg_->quant_args_ + i));
}
}
if (conv_quant_arg_->input_quant_args_ != nullptr) {
free(conv_quant_arg_->input_quant_args_);
}
if (conv_quant_arg_->filter_quant_args_ != nullptr) {
free(conv_quant_arg_->filter_quant_args_);
}
if (conv_quant_arg_->output_quant_args_ != nullptr) {
free(conv_quant_arg_->output_quant_args_);
}
}
@ -103,53 +105,218 @@ int ConvolutionBaseCPUKernel::CheckLayout(lite::tensor::Tensor *input_tensor) {
return RET_OK;
}
int ConvolutionBaseCPUKernel::SetQuantParam() {
ConvQuantArg *conv_quant_arg_ = &conv_param_->conv_quant_arg_;
conv_quant_arg_->quant_args_ = reinterpret_cast<QuantArg **>(malloc(3 * sizeof(QuantArg *)));
if (conv_quant_arg_->quant_args_ == nullptr) {
MS_LOG(ERROR) << "malloc quant_args_ failed.";
return RET_ERROR;
}
// per-tensor init
for (int j = 0; j < 3; ++j) {
conv_quant_arg_->quant_args_[j] = reinterpret_cast<QuantArg *>(malloc(sizeof(QuantArg)));
if (conv_quant_arg_->quant_args_[j] == nullptr) {
MS_LOG(ERROR) << "malloc quant_args_ failed.";
int ConvolutionBaseCPUKernel::SetIfPerChannel() {
uint8_t per_channel = 0b0;
if (conv_quant_arg_->input_arg_num_ != kPerTensor) {
int in_channel = conv_param_->input_channel_;
if (conv_quant_arg_->input_arg_num_ != in_channel) {
MS_LOG(ERROR) << "input per channel quant param length is not equal to input channel.";
return RET_ERROR;
}
per_channel = per_channel | INPUT_PER_CHANNEL;
}
if (conv_quant_arg_->filter_arg_num_ != kPerTensor) {
int filter_num = conv_param_->output_channel_;
if (conv_quant_arg_->filter_arg_num_ != filter_num) {
MS_LOG(ERROR) << "weight per channel quant param length is not equal to filter num.";
return RET_ERROR;
}
per_channel = per_channel | FILTER_PER_CHANNEL;
}
if (conv_quant_arg_->output_arg_num_ != kPerTensor) {
int out_channel = conv_param_->output_channel_;
if (conv_quant_arg_->output_arg_num_ != out_channel) {
MS_LOG(ERROR) << "output per channel quant param length is not equal to output channel.";
return RET_ERROR;
}
per_channel = per_channel | OUTPUT_PER_CHANNEL;
}
conv_quant_arg_->per_channel_ = per_channel;
return RET_OK;
}
int ConvolutionBaseCPUKernel::SetIfAsymmetric() {
uint8_t asymmetric = 0b0;
auto filter_tensor = in_tensors_.at(kWeightIndex);
auto filter_ele_num = filter_tensor->ElementsNum();
auto filter_data = reinterpret_cast<float *>(filter_tensor->Data());
float min_value = FLT_MAX;
float max_value = -FLT_MAX;
for (int i = 0; i < filter_ele_num; ++i) {
min_value = min_value < filter_data[i] ? min_value : filter_data[i];
max_value = max_value > filter_data[i] ? max_value : filter_data[i];
}
if (conv_quant_arg_->filter_arg_num_ == kPerTensor) {
auto filter_zp = conv_quant_arg_->filter_quant_args_[0].zp_;
if (filter_zp == 0 && min_value >= -127 && max_value <= 127) {
asymmetric = asymmetric & FILTER_ASYMMETRIC;
}
} else {
auto filter_arg = conv_quant_arg_->filter_quant_args_;
for (int i = 0; i < conv_param_->output_channel_; ++i) {
if (filter_arg[i].zp_ == 0 && min_value >= -127 && max_value <= 127) {
asymmetric = asymmetric & FILTER_ASYMMETRIC;
}
}
}
conv_quant_arg_->asymmetric_ = asymmetric;
return RET_OK;
}
int ConvolutionBaseCPUKernel::MallocQuantParam() {
conv_quant_arg_ = &conv_param_->conv_quant_arg_;
auto input_tensor = in_tensors_.at(kInputIndex);
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto output_tensor = out_tensors_.at(kOutputIndex);
auto input_quant_arg = input_tensor->GetQuantParams().front();
auto weight_quant_arg = weight_tensor->GetQuantParams().front();
auto output_quant_arg = output_tensor->GetQuantParams().front();
// input
conv_quant_arg_->quant_args_[0][0].zp_ = input_quant_arg.zeroPoint;
conv_quant_arg_->quant_args_[0][0].scale_ = input_quant_arg.scale;
// weight
conv_quant_arg_->quant_args_[1][0].zp_ = weight_quant_arg.zeroPoint;
conv_quant_arg_->quant_args_[1][0].scale_ = weight_quant_arg.scale;
// output
conv_quant_arg_->quant_args_[2][0].zp_ = output_quant_arg.zeroPoint;
conv_quant_arg_->quant_args_[2][0].scale_ = output_quant_arg.scale;
size_t input_arg_num = input_tensor->GetQuantParams().size();
size_t filter_arg_num = weight_tensor->GetQuantParams().size();
size_t output_arg_num = output_tensor->GetQuantParams().size();
conv_quant_arg_->input_arg_num_ = input_arg_num;
conv_quant_arg_->filter_arg_num_ = filter_arg_num;
conv_quant_arg_->output_arg_num_ = output_arg_num;
conv_quant_arg_->real_multiplier_ = reinterpret_cast<double *>(malloc(sizeof(double)));
conv_quant_arg_->left_shift_ = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
conv_quant_arg_->right_shift_ = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
conv_quant_arg_->quant_multiplier_ = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
conv_quant_arg_->input_quant_args_ = reinterpret_cast<QuantArg *>(malloc(input_arg_num * sizeof(QuantArg)));
if (conv_quant_arg_->input_quant_args_ == nullptr) {
MS_LOG(ERROR) << "malloc input_quant_args_ failed.";
return RET_ERROR;
}
conv_quant_arg_->filter_quant_args_ = reinterpret_cast<QuantArg *>(malloc(filter_arg_num * sizeof(QuantArg)));
if (conv_quant_arg_->filter_quant_args_ == nullptr) {
MS_LOG(ERROR) << "malloc filter_quant_args_ failed.";
return RET_ERROR;
}
conv_quant_arg_->output_quant_args_ = reinterpret_cast<QuantArg *>(malloc(output_arg_num * sizeof(QuantArg)));
if (conv_quant_arg_->output_quant_args_ == nullptr) {
MS_LOG(ERROR) << "malloc output_quant_args_ failed.";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionBaseCPUKernel::SetInputTensorQuantParam() {
auto input_tensor = in_tensors_.at(kInputIndex);
auto in_arg_num = conv_quant_arg_->input_arg_num_;
if (in_arg_num == kPerTensor) {
auto input_quant_arg = input_tensor->GetQuantParams().front();
conv_quant_arg_->input_quant_args_[0].zp_ = input_quant_arg.zeroPoint;
conv_quant_arg_->input_quant_args_[0].scale_ = input_quant_arg.scale;
} else {
// per channel
MS_LOG(ERROR) << "Not Support Per Channel for input now.";
return RET_ERROR;
// auto input_quant_arg = input_tensor->GetQuantParams();
// for (int i = 0; i < in_arg_num; ++i) {
// conv_quant_arg_->input_quant_args_[i].zp_ = input_quant_arg[i].zeroPoint;
// conv_quant_arg_->input_quant_args_[i].scale_ = input_quant_arg[i].scale;
// }
}
return RET_OK;
}
int ConvolutionBaseCPUKernel::SetFilterTensorQuantParam() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto weight_arg_num = conv_quant_arg_->filter_arg_num_;
if (weight_arg_num == kPerTensor) {
auto weight_quant_arg = weight_tensor->GetQuantParams().front();
conv_quant_arg_->filter_quant_args_[0].zp_ = weight_quant_arg.zeroPoint;
conv_quant_arg_->filter_quant_args_[0].scale_ = weight_quant_arg.scale;
} else {
auto weight_quant_arg = weight_tensor->GetQuantParams();
for (int i = 0; i < weight_arg_num; ++i) {
conv_quant_arg_->filter_quant_args_[i].zp_ = weight_quant_arg[i].zeroPoint;
conv_quant_arg_->filter_quant_args_[i].scale_ = weight_quant_arg[i].scale;
}
}
return RET_OK;
}
int ConvolutionBaseCPUKernel::SetOutputTensorQuantParam() {
auto output_tensor = out_tensors_.at(kOutputIndex);
auto out_arg_num = conv_quant_arg_->output_arg_num_;
if (out_arg_num == kPerTensor) {
auto output_quant_arg = output_tensor->GetQuantParams().front();
conv_quant_arg_->output_quant_args_[0].zp_ = output_quant_arg.zeroPoint;
conv_quant_arg_->output_quant_args_[0].scale_ = output_quant_arg.scale;
} else {
MS_LOG(ERROR) << "Not Support Per Channel for input now.";
return RET_ERROR;
// auto output_quant_arg = output_tensor->GetQuantParams();
// for (int i = 0; i < out_arg_num; ++i) {
// conv_quant_arg_->output_quant_args_[i].zp_ = output_quant_arg[i].zeroPoint;
// conv_quant_arg_->output_quant_args_[i].scale_ = output_quant_arg[i].scale;
// }
}
return RET_OK;
}
int ConvolutionBaseCPUKernel::SetQuantMultiplier() {
// now only support weight tensor is per channel, others are per tensor.
int weight_arg_num = kPerTensor;
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
weight_arg_num = conv_quant_arg_->filter_arg_num_;
}
conv_quant_arg_->real_multiplier_ = reinterpret_cast<double *>(malloc(weight_arg_num * sizeof(double)));
conv_quant_arg_->left_shift_ = reinterpret_cast<int32_t *>(malloc(weight_arg_num * sizeof(int32_t)));
conv_quant_arg_->right_shift_ = reinterpret_cast<int32_t *>(malloc(weight_arg_num * sizeof(int32_t)));
conv_quant_arg_->quant_multiplier_ = reinterpret_cast<int32_t *>(malloc(weight_arg_num * sizeof(int32_t)));
conv_quant_arg_->out_act_min_ = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
conv_quant_arg_->out_act_max_ = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
double real_multiplier = weight_quant_arg.scale * input_quant_arg.scale / output_quant_arg.scale;
conv_quant_arg_->real_multiplier_[0] = real_multiplier;
QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[0], &conv_quant_arg_->left_shift_[0],
&conv_quant_arg_->right_shift_[0]);
for (int i = 0; i < weight_arg_num; ++i) {
double real_multiplier = conv_quant_arg_->filter_quant_args_[i].scale_ *
conv_quant_arg_->input_quant_args_[0].scale_ /
conv_quant_arg_->output_quant_args_[0].scale_;
conv_quant_arg_->real_multiplier_[i] = real_multiplier;
QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i],
&conv_quant_arg_->right_shift_[i]);
}
return RET_OK;
}
int ConvolutionBaseCPUKernel::SetQuantParam() {
auto ret = MallocQuantParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Malloc quant param failed.";
return ret;
}
ret = SetInputTensorQuantParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Set Input Tensor Quant Param Failed.";
return ret;
}
ret = SetFilterTensorQuantParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Set Filter Tensor Quant Param Failed.";
return ret;
}
ret = SetOutputTensorQuantParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Set Output Tensor Quant Param Failed.";
return ret;
}
ret = SetQuantMultiplier();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Set Quant Multiplier Failed.";
return ret;
}
// now only consider per tensor for output
CalculateActivationRangeQuantized(
conv_param_->is_relu_, conv_param_->is_relu6_, conv_param_->conv_quant_arg_.quant_args_[2][0].zp_,
conv_param_->conv_quant_arg_.quant_args_[2][0].scale_, &conv_param_->conv_quant_arg_.out_act_min_[0],
conv_param_->is_relu_, conv_param_->is_relu6_, conv_param_->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param_->conv_quant_arg_.output_quant_args_[0].scale_, &conv_param_->conv_quant_arg_.out_act_min_[0],
&conv_param_->conv_quant_arg_.out_act_max_[0]);
ret = SetIfPerChannel();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Set if per tensor channel failed.";
return ret;
}
ret = SetIfAsymmetric();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Set if per asymmetric failed.";
return ret;
}
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -32,6 +32,7 @@
using mindspore::lite::Context;
using mindspore::schema::PadMode;
using mindspore::schema::QuantType;
static constexpr int kPerTensor = 1;
namespace mindspore::kernel {
class ConvolutionBaseCPUKernel : public LiteKernel {
@ -49,7 +50,14 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
int ReSize() override { return 0; }
int Run() override { return 0; }
virtual int CheckLayout(lite::tensor::Tensor *input_tensor);
int SetIfAsymmetric();
int SetIfPerChannel();
int MallocQuantParam();
int SetQuantParam();
int SetInputTensorQuantParam();
int SetFilterTensorQuantParam();
int SetOutputTensorQuantParam();
int SetQuantMultiplier();
void FreeQuantParam();
protected:
@ -59,9 +67,9 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
void *nhwc4_input_ = nullptr;
const Context *ctx_;
ConvParameter *conv_param_;
ConvQuantArg *conv_quant_arg_;
LayoutConvertor convert_func_;
};
bool CheckSupportFP16();
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_

View File

@ -69,8 +69,8 @@ int ConvolutionInt8CPUKernel::InitWeightBias() {
int kernel_plane = kernel_h * kernel_w;
int plane_c4 = UP_DIV(kernel_plane, C4NUM);
int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * plane_c4 * C4NUM;
int32_t filter_zp = conv_param_->conv_quant_arg_.quant_args_[1][0].zp_;
int32_t input_zp = conv_param_->conv_quant_arg_.quant_args_[0][0].zp_;
auto filter_arg = conv_param_->conv_quant_arg_.filter_quant_args_;
int32_t input_zp = conv_param_->conv_quant_arg_.input_quant_args_[0].zp_;
// init weight
auto origin_weight = reinterpret_cast<int8_t *>(in_tensors_.at(kWeightIndex)->Data());
@ -99,8 +99,14 @@ int ConvolutionInt8CPUKernel::InitWeightBias() {
}
auto *bias_data = reinterpret_cast<int32_t *>(bias_data_);
int c4_kernel_plane_size = kernel_plane * ic4 * C4NUM;
for (int i = 0; i < out_channel; i++) {
bias_data[i] += filter_zp * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
for (int i = 0; i < out_channel; i++) {
bias_data[i] += filter_arg[i].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
}
} else {
for (int i = 0; i < out_channel; i++) {
bias_data[i] += filter_arg[0].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
}
}
free(weight_sum);
return RET_OK;
@ -125,7 +131,13 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() {
memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size);
/*=============================input_sum_============================*/
input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_num_ * thread_count_ * sizeof(int32_t)));
size_t input_sum_size;
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
input_sum_size = conv_param_->output_channel_ * tile_num_ * thread_count_ * sizeof(int32_t);
} else {
input_sum_size = tile_num_ * thread_count_ * sizeof(int32_t);
}
input_sum_ = reinterpret_cast<int32_t *>(malloc(input_sum_size));
if (input_sum_ == nullptr) {
MS_LOG(ERROR) << "malloc input_sum_ failed.";
return RET_ERROR;
@ -168,8 +180,8 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
int oc4 = UP_DIV(out_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane;
int32_t filter_zp = conv_param_->conv_quant_arg_.quant_args_[1][0].zp_;
int32_t input_zp = conv_param_->conv_quant_arg_.quant_args_[0][0].zp_;
auto filter_arg = conv_param_->conv_quant_arg_.filter_quant_args_;
int32_t input_zp = conv_param_->conv_quant_arg_.input_quant_args_[0].zp_;
// init weight
auto origin_weight = reinterpret_cast<int8_t *>(in_tensors_.at(kWeightIndex)->Data());
@ -178,9 +190,9 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
MS_LOG(ERROR) << "malloc packed_weight_ failed.";
return RET_ERROR;
}
memset(packed_weight_, filter_zp, pack_weight_size);
memset(packed_weight_, 0, pack_weight_size);
auto *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * out_channel));
for (int i = 0; i < out_channel; i++) weight_sum[i] = filter_zp * ic4 * C4NUM * kernel_plane;
for (int i = 0; i < out_channel; i++) weight_sum[i] = 0;
PackWeightInt8Opt(origin_weight, conv_param_, packed_weight_, weight_sum);
// init bias
@ -198,8 +210,14 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
}
auto *bias_data = reinterpret_cast<int32_t *>(bias_data_);
int c4_kernel_plane_size = kernel_plane * ic4 * C4NUM;
for (int i = 0; i < out_channel; i++) {
bias_data[i] += filter_zp * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
for (int i = 0; i < out_channel; i++) {
bias_data[i] += filter_arg[i].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
}
} else {
for (int i = 0; i < out_channel; i++) {
bias_data[i] += filter_arg[0].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
}
}
free(weight_sum);
return RET_OK;
@ -223,7 +241,13 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() {
memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size);
/*=============================input_sum_============================*/
input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_num_ * thread_count_ * sizeof(int32_t)));
size_t input_sum_size;
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
input_sum_size = conv_param_->output_channel_ * tile_num_ * thread_count_ * sizeof(int32_t);
} else {
input_sum_size = tile_num_ * thread_count_ * sizeof(int32_t);
}
input_sum_ = reinterpret_cast<int32_t *>(malloc(input_sum_size));
if (input_sum_ == nullptr) {
MS_LOG(ERROR) << "malloc input_sum_ failed.";
return RET_ERROR;

View File

@ -77,7 +77,7 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, sliding->in_kh_step_,
sliding->in_kw_step_, conv_param->kernel_w_, conv_param->conv_quant_arg_.quant_multiplier_[0],
conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0],
conv_param->conv_quant_arg_.quant_args_[2][0].zp_, conv_param->conv_quant_arg_.out_act_min_[0],
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0],
conv_param->conv_quant_arg_.out_act_max_[0]);
dst_kernel += sliding->block_channel_;
@ -168,15 +168,15 @@ void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *w
sliding->in_sw_step_ * sizeof(int16_t), sliding->in_kh_step_ * sizeof(int16_t),
sliding->in_kw_step_ * sizeof(int16_t), conv_param->conv_quant_arg_.quant_multiplier_[0],
conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0],
conv_param->conv_quant_arg_.quant_args_[2][0].zp_, conv_param->conv_quant_arg_.out_act_min_[0],
conv_param->conv_quant_arg_.out_act_max_[0]);
conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
#else
DepthwiseCenterInt8(
out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_,
conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_,
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
#endif
}
@ -333,7 +333,7 @@ void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *in
DeconvDepthwisePostFuncInt8(
dst_data, output_buffer, bias, sliding->block_channel_, conv_param,
conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_,
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
} // output C4 loop
src += sliding->in_step_;

View File

@ -22,10 +22,10 @@
void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
int ic4, size_t kernel_plane, size_t output_channel, const int32_t *input_sum,
ConvParameter *conv_param) {
int32_t shift_before = conv_param->conv_quant_arg_.left_shift_[0];
int32_t shift_after = conv_param->conv_quant_arg_.right_shift_[0];
int32_t out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0];
int32_t out_zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_;
int32_t *shift_before = conv_param->conv_quant_arg_.left_shift_;
int32_t *shift_after = conv_param->conv_quant_arg_.right_shift_;
int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
int32_t out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0];
int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0];
#ifdef __aarch64__
@ -63,14 +63,49 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
} // in c4num loop
} // ic4 loop
} // kernel_plane loop
tmp_dst[dst_tile_offset] -= input_sum[n];
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before), out_multiplier), -shift_after);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]),
-shift_after[oc]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
} else if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[0]), out_multiplier[0]),
-shift_after[0]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
tmp_dst[dst_tile_offset] -= input_sum[n];
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[0]), out_multiplier[0]),
-shift_after[0]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
tmp_dst[dst_tile_offset] -= input_sum[n * output_channel + oc];
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]),
-shift_after[oc]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
}
} // tile_num loop
} // output_channel loop
#endif
@ -79,10 +114,10 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
int ic4, size_t kernel_plane, size_t output_channel, const int32_t *input_sum,
ConvParameter *conv_param, GEMM_FUNC gemm_func) {
int32_t shift_before = conv_param->conv_quant_arg_.left_shift_[0];
int32_t shift_after = conv_param->conv_quant_arg_.right_shift_[0];
int32_t out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0];
int32_t out_zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_;
int32_t *shift_before = conv_param->conv_quant_arg_.left_shift_;
int32_t *shift_after = conv_param->conv_quant_arg_.right_shift_;
int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
int32_t out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0];
int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0];
if (gemm_func != NULL) {
@ -113,14 +148,49 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const
} // in c4num loop
} // ic4 loop
} // kernel_plane loop
tmp_dst[dst_tile_offset] -= input_sum[n];
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before), out_multiplier), -shift_after);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]),
-shift_after[oc]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
} else if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[0]), out_multiplier[0]),
-shift_after[0]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
tmp_dst[dst_tile_offset] -= input_sum[n];
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[0]), out_multiplier[0]),
-shift_after[0]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
tmp_dst[dst_tile_offset] -= input_sum[n * output_channel + oc];
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]),
-shift_after[oc]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
}
} // tile_num loop
} // output_channel loop
}
@ -182,7 +252,7 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_;
int32_t input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
int tile_n = conv_param->tile_num_;
int thread_count = conv_param->thread_num_;
@ -238,7 +308,7 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_;
int32_t input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
int tile_n = conv_param->tile_num_;
int thread_count = conv_param->thread_num_;
int output_count = out_h * out_w;

View File

@ -19,8 +19,8 @@
int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, size_t row8, size_t col8, size_t deep,
ConvParameter *conv_param) {
MatMulInt8(input, weight, output, row8, col8, deep, conv_param->conv_quant_arg_.quant_args_[0][0].zp_,
conv_param->conv_quant_arg_.quant_args_[1][0].zp_);
MatMulInt8(input, weight, output, row8, col8, deep, conv_param->conv_quant_arg_.input_quant_args_[0].zp_,
conv_param->conv_quant_arg_.filter_quant_args_[0].zp_);
return NNACL_OK;
}
@ -65,7 +65,7 @@ int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t
PostFuncInt8(tmp, bias, out, output_channel, output_plane, UP_ROUND(output_plane, 8),
conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_,
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
return NNACL_OK;
}

View File

@ -115,7 +115,6 @@ void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *p
int oc4 = UP_DIV(out_channel, C4NUM);
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane;
int unit_size = C4NUM * C4NUM;
int block_size = pack_weight_size / oc4;
@ -143,7 +142,7 @@ void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *p
if (packed_data_ptr[0] == -128) {
packed_data_ptr[0] = -127;
}
weight_sum[j * C4NUM + k] += (int32_t)(packed_data_ptr[0] - filter_zp);
weight_sum[j * C4NUM + k] += (int32_t)(packed_data_ptr[0]);
}
} // kernel block loop
} // inchannel block loop
@ -241,7 +240,7 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
int32_t *input_sum, ConvParameter *conv_param) {
// input format : nhwc
int tile_num = conv_param->tile_num_;
int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
QuantArg *filter_arg = conv_param->conv_quant_arg_.filter_quant_args_;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
@ -292,7 +291,18 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
} // channel_block loop
} // kernel_w loop
} // kernel_h loop
input_sum[i] = input_accumulator * filter_zp;
if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC)) {
return;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int cal_num_offset = i * conv_param->output_channel_;
for (int l = 0; l < conv_param->output_channel_; ++l) {
input_sum[cal_num_offset + l] = input_accumulator * filter_arg[i].zp_;
}
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
input_sum[i] = input_accumulator * filter_arg[0].zp_;
}
} // tile num loop
}
@ -300,7 +310,7 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
int32_t *input_sum, ConvParameter *conv_param) {
// input format : nhwc
int tile_num = conv_param->tile_num_;
int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
QuantArg *filter_arg = conv_param->conv_quant_arg_.filter_quant_args_;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
@ -348,13 +358,23 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
int block_offset = j * tile_num * ic4 * C4NUM + i * C4NUM;
for (int c = 0; c < ic4; c++) {
int ic4_offset = block_offset + c * tile_num * C4NUM;
input_accumulator += (packed_input + ic4_offset)[0];
input_accumulator += (packed_input + ic4_offset)[1];
input_accumulator += (packed_input + ic4_offset)[2];
input_accumulator += (packed_input + ic4_offset)[3];
for (int k = 0; k < C4NUM; ++k) {
input_accumulator += (packed_input + ic4_offset)[k];
}
}
}
input_sum[i] = input_accumulator * filter_zp;
if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC)) {
return;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int cal_num_offset = i * conv_param->output_channel_;
for (int l = 0; l < conv_param->output_channel_; ++l) {
input_sum[cal_num_offset + l] = input_accumulator * filter_arg[i].zp_;
}
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
input_sum[i] = input_accumulator * filter_arg[0].zp_;
}
} // tile num loop
}
@ -387,7 +407,7 @@ void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight
int input_channel = conv_param->input_channel_;
int ic8 = UP_DIV(input_channel, C8NUM);
int output_channel = conv_param->output_channel_;
int filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
QuantArg *filter_zp = conv_param->conv_quant_arg_.filter_quant_args_;
int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_;
for (int k = 0; k < kernel_plane; k++) {
@ -401,7 +421,7 @@ void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight
int c8_block_rem = i % C8NUM;
int src_ic_offset = src_oc_offset + i;
int dst_ic_offset = dst_oc_offset + c8_block_num * kernel_plane * C8NUM + c8_block_rem;
(packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - filter_zp);
(packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - filter_zp[o].zp_);
}
}
}
@ -806,7 +826,7 @@ void MatrixPack(const float *src, float *dst, int row, int ic4, int stride) {
}
void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param) {
int input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_;
int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
int unit = conv_param->input_h_ * conv_param->input_w_;
@ -824,7 +844,7 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter
}
void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, const ConvParameter *conv_param) {
int weight_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
int weight_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
int unit = conv_param->kernel_h_ * conv_param->kernel_w_;
for (int c = 0; c < conv_param->output_channel_; c++) {
int c4_block_num = c / C4NUM;

View File

@ -17,25 +17,37 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_QUANTIZATION_QUANTIZE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_QUANTIZATION_QUANTIZE_H_
#include <stdint.h>
#include <math.h>
#include <stdlib.h>
#include <limits.h>
#include "nnacl/op_base.h"
#define INPUT_ASYMMETRIC 0b001
#define FILTER_ASYMMETRIC 0b010
#define OUTPUT_ASYMMETRIC 0b100
#define INPUT_PER_CHANNEL 0b001
#define FILTER_PER_CHANNEL 0b010
#define OUTPUT_PER_CHANNEL 0b100
typedef struct QuantArg {
double scale_;
int32_t zp_;
} QuantArg;
typedef struct ConvQuantArg {
QuantArg **quant_args_;
QuantArg *input_quant_args_;
QuantArg *filter_quant_args_;
QuantArg *output_quant_args_;
double *real_multiplier_;
int32_t *left_shift_;
int32_t *right_shift_;
int32_t *quant_multiplier_;
int32_t *out_act_min_;
int32_t *out_act_max_;
size_t input_arg_num_;
size_t filter_arg_num_;
size_t output_arg_num_;
uint8_t asymmetric_;
uint8_t per_channel_;
} ConvQuantArg;
typedef struct ConcatQuantArg {

View File

@ -854,7 +854,7 @@ void Conv3x3Uint8InputTransform(const int16_t *input_data, int16_t *trans_input,
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_h_;
ConvQuantArg quant_arg = conv_param->conv_quant_arg_;
int input_zp = quant_arg.quant_args_[0][0].zp_;
int input_zp = quant_arg.input_quant_args_[0].zp_;
int ic8 = UP_DIV(input_channel, C8NUM);
int input_unit = 4;
@ -1155,11 +1155,11 @@ void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weigh
}
void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound,
bool w_not_bound, int output_w, int real_num, ConvParameter *conv_param) {
int left_shift = conv_param->conv_quant_arg_.left_shift_[0];
int right_shift = conv_param->conv_quant_arg_.right_shift_[0];
int quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0];
int output_zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_;
bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) {
int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_;
int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_;
int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
int out_min = conv_param->conv_quant_arg_.out_act_min_[0];
int out_max = conv_param->conv_quant_arg_.out_act_max_[0];
@ -1202,12 +1202,21 @@ void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, i
int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr);
int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr);
int32x4_t out_multiplier = vdupq_n_s32(quant_multiplier);
int32x4_t out_multiplier;
int32x4_t ls;
int32x4_t rs;
if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
out_multiplier = vld1q_s32(quant_multiplier);
ls = vld1q_s32(left_shift);
rs = vld1q_s32(right_shift);
} else {
out_multiplier = vdupq_n_s32(quant_multiplier);
ls = vdupq_n_s32(left_shift);
rs = vdupq_n_s32(right_shift);
}
int32x4_t out_zp = vdupq_n_s32(output_zp);
int32x4_t output_min = vdupq_n_s32(out_min);
int32x4_t output_max = vdupq_n_s32(out_max);
int32x4_t ls = vdupq_n_s32(left_shift);
int32x4_t rs = vdupq_n_s32(right_shift);
d00 = vqshlq_s32(d00, ls);
d00 = vqrdmulhq_s32(d00, out_multiplier);
@ -1261,78 +1270,166 @@ void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, i
}
}
#else
for (int i = 0; i < C4NUM; i++) {
const int32_t *local_ptr = gemm_out + i;
const int32_t *bias_ptr = bias_data + i;
if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
for (int i = 0; i < C4NUM; i++) {
const int32_t *local_ptr = gemm_out + i;
const int32_t *bias_ptr = bias_data + i;
int32_t s00 = local_ptr[0];
int32_t s01 = (local_ptr + 4)[0];
int32_t s02 = (local_ptr + 8)[0];
int32_t s03 = (local_ptr + 12)[0];
int32_t s00 = local_ptr[0];
int32_t s01 = (local_ptr + 4)[0];
int32_t s02 = (local_ptr + 8)[0];
int32_t s03 = (local_ptr + 12)[0];
int32_t s10 = (local_ptr + 16)[0];
int32_t s11 = (local_ptr + 20)[0];
int32_t s12 = (local_ptr + 24)[0];
int32_t s13 = (local_ptr + 28)[0];
int32_t s10 = (local_ptr + 16)[0];
int32_t s11 = (local_ptr + 20)[0];
int32_t s12 = (local_ptr + 24)[0];
int32_t s13 = (local_ptr + 28)[0];
int32_t s20 = (local_ptr + 32)[0];
int32_t s21 = (local_ptr + 36)[0];
int32_t s22 = (local_ptr + 40)[0];
int32_t s23 = (local_ptr + 44)[0];
int32_t s20 = (local_ptr + 32)[0];
int32_t s21 = (local_ptr + 36)[0];
int32_t s22 = (local_ptr + 40)[0];
int32_t s23 = (local_ptr + 44)[0];
int32_t s30 = (local_ptr + 48)[0];
int32_t s31 = (local_ptr + 52)[0];
int32_t s32 = (local_ptr + 56)[0];
int32_t s33 = (local_ptr + 60)[0];
int32_t s30 = (local_ptr + 48)[0];
int32_t s31 = (local_ptr + 52)[0];
int32_t s32 = (local_ptr + 56)[0];
int32_t s33 = (local_ptr + 60)[0];
int32_t t00 = (s00 + s10 + s20) / 2;
int32_t t01 = (s01 + s11 + s21) / 2;
int32_t t02 = (s02 + s12 + s22) / 2;
int32_t t03 = (s03 + s13 + s23) / 2;
int32_t t00 = (s00 + s10 + s20) / 2;
int32_t t01 = (s01 + s11 + s21) / 2;
int32_t t02 = (s02 + s12 + s22) / 2;
int32_t t03 = (s03 + s13 + s23) / 2;
int32_t t10 = (s10 - s20 - s30) / 2;
int32_t t11 = (s11 - s21 - s31) / 2;
int32_t t12 = (s12 - s22 - s32) / 2;
int32_t t13 = (s13 - s23 - s33) / 2;
int32_t t10 = (s10 - s20 - s30) / 2;
int32_t t11 = (s11 - s21 - s31) / 2;
int32_t t12 = (s12 - s22 - s32) / 2;
int32_t t13 = (s13 - s23 - s33) / 2;
int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0];
int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0];
int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0];
int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0];
int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0];
int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0];
int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0];
int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0];
d00 = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift);
d00 += output_zp;
d00 = d00 > out_min ? d00 : out_min;
d00 = d00 < out_max ? d00 : out_max;
int oc_index = oc_start + i;
d00 = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
-right_shift[oc_index]);
d00 += output_zp;
d00 = d00 > out_min ? d00 : out_min;
d00 = d00 < out_max ? d00 : out_max;
d01 = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift);
d01 += output_zp;
d01 = d01 > out_min ? d01 : out_min;
d01 = d01 < out_max ? d01 : out_max;
d01 = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
-right_shift[oc_index]);
d01 += output_zp;
d01 = d01 > out_min ? d01 : out_min;
d01 = d01 < out_max ? d01 : out_max;
d10 = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift);
d10 += output_zp;
d10 = d10 > out_min ? d10 : out_min;
d10 = d10 < out_max ? d10 : out_max;
d10 = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
-right_shift[oc_index]);
d10 += output_zp;
d10 = d10 > out_min ? d10 : out_min;
d10 = d10 < out_max ? d10 : out_max;
d11 = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift);
d11 += output_zp;
d11 = d11 > out_min ? d11 : out_min;
d11 = d11 < out_max ? d11 : out_max;
d11 = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
-right_shift[oc_index]);
d11 += output_zp;
d11 = d11 > out_min ? d11 : out_min;
d11 = d11 < out_max ? d11 : out_max;
(output_data + i)[0] = (int8_t)d00;
if (w_not_bound) {
(output_data + i + C4NUM)[0] = (int8_t)d01;
}
if (h_not_bound) {
(output_data + i + output_w * C4NUM)[0] = (int8_t)d10;
(output_data + i)[0] = (int8_t)d00;
if (w_not_bound) {
(output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11;
(output_data + i + C4NUM)[0] = (int8_t)d01;
}
if (h_not_bound) {
(output_data + i + output_w * C4NUM)[0] = (int8_t)d10;
if (w_not_bound) {
(output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11;
}
}
}
} else {
for (int i = 0; i < C4NUM; i++) {
const int32_t *local_ptr = gemm_out + i;
const int32_t *bias_ptr = bias_data + i;
int32_t s00 = local_ptr[0];
int32_t s01 = (local_ptr + 4)[0];
int32_t s02 = (local_ptr + 8)[0];
int32_t s03 = (local_ptr + 12)[0];
int32_t s10 = (local_ptr + 16)[0];
int32_t s11 = (local_ptr + 20)[0];
int32_t s12 = (local_ptr + 24)[0];
int32_t s13 = (local_ptr + 28)[0];
int32_t s20 = (local_ptr + 32)[0];
int32_t s21 = (local_ptr + 36)[0];
int32_t s22 = (local_ptr + 40)[0];
int32_t s23 = (local_ptr + 44)[0];
int32_t s30 = (local_ptr + 48)[0];
int32_t s31 = (local_ptr + 52)[0];
int32_t s32 = (local_ptr + 56)[0];
int32_t s33 = (local_ptr + 60)[0];
int32_t t00 = (s00 + s10 + s20) / 2;
int32_t t01 = (s01 + s11 + s21) / 2;
int32_t t02 = (s02 + s12 + s22) / 2;
int32_t t03 = (s03 + s13 + s23) / 2;
int32_t t10 = (s10 - s20 - s30) / 2;
int32_t t11 = (s11 - s21 - s31) / 2;
int32_t t12 = (s12 - s22 - s32) / 2;
int32_t t13 = (s13 - s23 - s33) / 2;
int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0];
int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0];
int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0];
int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0];
d00 = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
-right_shift[0]);
d00 += output_zp;
d00 = d00 > out_min ? d00 : out_min;
d00 = d00 < out_max ? d00 : out_max;
d01 = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
-right_shift[0]);
d01 += output_zp;
d01 = d01 > out_min ? d01 : out_min;
d01 = d01 < out_max ? d01 : out_max;
d10 = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
-right_shift[0]);
d10 += output_zp;
d10 = d10 > out_min ? d10 : out_min;
d10 = d10 < out_max ? d10 : out_max;
d11 = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
-right_shift[0]);
d11 += output_zp;
d11 = d11 > out_min ? d11 : out_min;
d11 = d11 < out_max ? d11 : out_max;
(output_data + i)[0] = (int8_t)d00;
if (w_not_bound) {
(output_data + i + C4NUM)[0] = (int8_t)d01;
}
if (h_not_bound) {
(output_data + i + output_w * C4NUM)[0] = (int8_t)d10;
if (w_not_bound) {
(output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11;
}
}
}
}
@ -1364,7 +1461,8 @@ void Conv3x3Uint8OutputTransform(const int32_t *gemm_out, int8_t *out_data, cons
int real_num = (output_channel - j * C4NUM) < C4NUM ? (output_channel - j * C4NUM) : C4NUM;
bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w;
bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h;
Conv3x3Uint8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, conv_param);
Conv3x3Uint8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, j * C4NUM,
conv_param);
}
}
}

View File

@ -65,7 +65,7 @@ void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weigh
int kernel_plane);
void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound,
bool w_not_bound, int output_w, int real_num, ConvParameter *conv_param);
bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param);
void Conv3x3Uint8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param);