forked from mindspore-Ecosystem/mindspore
Modify the quantization operation of softmax op
This commit is contained in:
parent
1ca715c7e7
commit
29c04de893
|
@ -37,17 +37,24 @@ int SoftmaxInt8CPUKernel::Init() {
|
|||
|
||||
auto in_quant_args = input_tensor->GetQuantParams();
|
||||
quant_params_.in_quant_args_.scale_ = in_quant_args.front().scale;
|
||||
quant_params_.in_quant_args_.zp_ = in_quant_args.front().zeroPoint;
|
||||
quant_params_.in_quant_args_.zp_ = -in_quant_args.front().zeroPoint;
|
||||
|
||||
auto *out_tensor = out_tensors_.at(kOutputIndex);
|
||||
MS_ASSERT(out_tensor);
|
||||
|
||||
auto out_quant_args = out_tensor->GetQuantParams();
|
||||
quant_params_.out_quant_arg_.scale_ = out_quant_args.front().scale;
|
||||
quant_params_.out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
|
||||
quant_params_.out_quant_arg_.zp_ = -out_quant_args.front().zeroPoint;
|
||||
quant_params_.output_activation_min_ = std::numeric_limits<int8_t>::min();
|
||||
quant_params_.output_activation_max_ = std::numeric_limits<int8_t>::max();
|
||||
|
||||
const double input_real_multiplier =
|
||||
MSMIN(quant_params_.in_quant_args_.scale_ * (1 << (unsigned int)(31 - 5)), (1ll << 31) - 1.0);
|
||||
int right_shift = 0;
|
||||
QuantizeMultiplierSmallerThanOne(input_real_multiplier, &quant_params_.output_multiplier_, &right_shift);
|
||||
quant_params_.shift_left_ = right_shift < 0 ? -right_shift : 0;
|
||||
quant_params_.shift_right_ = right_shift > 0 ? right_shift : 0;
|
||||
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -72,12 +79,12 @@ int SoftmaxInt8CPUKernel::ReSize() {
|
|||
return ret;
|
||||
}
|
||||
FreeTmpBuffer();
|
||||
exp_data_ = reinterpret_cast<float *>(malloc(softmax_param_->element_size_ * sizeof(float)));
|
||||
exp_data_ = reinterpret_cast<int *>(malloc(softmax_param_->element_size_ * sizeof(int)));
|
||||
int inner_size = 1;
|
||||
for (int i = softmax_param_->axis_ + 1; i < softmax_param_->n_dim_; i++) {
|
||||
inner_size *= softmax_param_->input_shape_[i];
|
||||
}
|
||||
sum_data_ = reinterpret_cast<float *>(malloc(inner_size * sizeof(float)));
|
||||
sum_data_ = reinterpret_cast<int *>(malloc(inner_size * sizeof(int)));
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -125,12 +132,7 @@ int SoftmaxInt8CPUKernel::Run() {
|
|||
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto input_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->Data());
|
||||
int ele_size = softmax_param_->element_size_;
|
||||
for (int i = 0; i < ele_size; i++) {
|
||||
float input_scaled = ((input_ptr[i] - quant_params_.in_quant_args_.zp_) * quant_params_.in_quant_args_.scale_);
|
||||
exp_data_[i] = exp(input_scaled);
|
||||
}
|
||||
|
||||
int error_code = LiteBackendParallelLaunch(SoftmaxRun, this, thread_count_);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "Softmax function error error_code[" << error_code << "]";
|
||||
|
|
|
@ -37,8 +37,8 @@ class SoftmaxInt8CPUKernel : public SoftmaxBaseCPUKernel {
|
|||
|
||||
private:
|
||||
void FreeTmpBuffer();
|
||||
float *sum_data_ = nullptr;
|
||||
float *exp_data_ = nullptr;
|
||||
int *sum_data_ = nullptr;
|
||||
int *exp_data_ = nullptr;
|
||||
SoftmaxQuantArg quant_params_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -16,17 +16,17 @@
|
|||
|
||||
#include "nnacl/int8/softmax_int8.h"
|
||||
#include <math.h>
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data,
|
||||
int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp_data, int *sum_data,
|
||||
SoftmaxQuantArg quant_param, SoftmaxParameter *parameter) {
|
||||
int32_t axis = parameter->axis_;
|
||||
int n_dim = parameter->n_dim_;
|
||||
int *input_shape = parameter->input_shape_;
|
||||
int axis_shape_size = input_shape[axis];
|
||||
|
||||
double output_scale = quant_param.out_quant_arg_.scale_;
|
||||
int32_t output_zp = quant_param.out_quant_arg_.zp_;
|
||||
|
||||
int inner_size = 1;
|
||||
for (int i = axis + 1; i < n_dim; i++) {
|
||||
inner_size *= input_shape[i];
|
||||
|
@ -34,22 +34,37 @@ int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, float *e
|
|||
|
||||
for (int o = 0; o < count; o++) {
|
||||
int outter_offset = o * axis_shape_size * inner_size;
|
||||
for (int i = 0; i < inner_size; i++) {
|
||||
float sum = 0;
|
||||
for (int j = 0; j < axis_shape_size; j++) {
|
||||
int axis_offset = outter_offset + i + j * inner_size;
|
||||
sum += exp_data[axis_offset];
|
||||
|
||||
for (int c = 0; c < inner_size; c++) {
|
||||
int8_t max_row = quant_param.output_activation_min_;
|
||||
for (int i = 0; i < axis_shape_size; ++i) {
|
||||
int axis_offset = outter_offset + c + i * inner_size;
|
||||
max_row = MSMAX(max_row, input_ptr[axis_offset]);
|
||||
}
|
||||
sum_data[i] = sum;
|
||||
|
||||
int32_t exp_sum = 0;
|
||||
for (int i = 0; i < axis_shape_size; ++i) {
|
||||
int axis_offset = outter_offset + c + i * inner_size;
|
||||
const int32_t input_val = input_ptr[axis_offset] - max_row;
|
||||
const int32_t input_scaled = SaturatingRoundingDoublingHighMul(
|
||||
input_val * (1 << (unsigned int)quant_param.shift_left_), quant_param.output_multiplier_);
|
||||
int exp_val = exp_on_negative_values(input_scaled, 5);
|
||||
exp_data[axis_offset] = exp_val;
|
||||
exp_sum = exp_sum + Rescale(exp_val, 0, 12);
|
||||
}
|
||||
sum_data[c] = exp_sum;
|
||||
}
|
||||
for (int j = 0; j < axis_shape_size; j++) {
|
||||
int axis_offset = outter_offset + j * inner_size;
|
||||
for (int i = 0; i < inner_size; i++) {
|
||||
int inner_offset = axis_offset + i;
|
||||
float real_output = exp_data[inner_offset] / sum_data[i];
|
||||
int32_t output_scaled = round(real_output / output_scale) + output_zp;
|
||||
output_ptr[inner_offset] =
|
||||
MSMAX(quant_param.output_activation_min_, MSMIN(quant_param.output_activation_max_, output_scaled));
|
||||
for (int i = 0; i < axis_shape_size; ++i) {
|
||||
int axis_offset = outter_offset + i * inner_size;
|
||||
for (int c = 0; c < inner_size; ++c) {
|
||||
int num_bits_over_unit;
|
||||
int shifted_scale = ComputerReciproal(sum_data[c], 12, &num_bits_over_unit);
|
||||
int unsat_output = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(shifted_scale, exp_data[axis_offset + c]), num_bits_over_unit + 31 - 8);
|
||||
|
||||
int raw_output = unsat_output + quant_param.output_activation_min_;
|
||||
output_ptr[axis_offset + c] =
|
||||
(int8_t)MSMAX(quant_param.output_activation_min_, MSMIN(raw_output, quant_param.output_activation_max_));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data,
|
||||
int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp_data, int *sum_data,
|
||||
SoftmaxQuantArg quant_param, SoftmaxParameter *parameter);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -86,24 +86,22 @@ int32_t MaskNonZero(int32_t a) {
|
|||
return a ? BitNot(zreo) : zreo;
|
||||
}
|
||||
|
||||
int SaturatingRoundingMultiplyByPOT(int32_t x, int Exponent) {
|
||||
int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0);
|
||||
if (ExponentSign == 0) {
|
||||
return x;
|
||||
} else if (ExponentSign == 1) {
|
||||
static inline int SaturatingRoundingMultiplyByPOT(int32_t x, int Exponent) {
|
||||
if (Exponent > 0) {
|
||||
const int min = INT32_MIN;
|
||||
const int max = INT32_MAX;
|
||||
const int thresold = ((1 << (uint32_t)(31 - Exponent)) - 1);
|
||||
const int scalar_int_bits = 8 * sizeof(int32_t);
|
||||
const int thresold = ((1 << (uint32_t)(scalar_int_bits - 1 - Exponent)) - 1);
|
||||
const int postive_mask = MaskNonZero(x > thresold);
|
||||
const int negative_mask = MaskNonZero(x < -thresold);
|
||||
int result = x << Exponent;
|
||||
int result = x * ((int32_t)(1) << (uint32_t)Exponent);
|
||||
result = SelectUsingMask(postive_mask, max, result);
|
||||
result = SelectUsingMask(negative_mask, min, result);
|
||||
return result;
|
||||
} else if (ExponentSign == -1) {
|
||||
} else if (Exponent < 0) {
|
||||
return RoundingDivideByPOT(x, -Exponent);
|
||||
} else {
|
||||
return 0;
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -113,7 +111,7 @@ int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst) {
|
|||
return result;
|
||||
}
|
||||
|
||||
static int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a) {
|
||||
int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a) {
|
||||
int one = FixedPoint_One(0, FractionsBits(0));
|
||||
int half_denominator = RoundingHalfSum(a, one);
|
||||
const int constant_48_over_17 = 1515870810;
|
||||
|
@ -159,6 +157,71 @@ int32_t ComputerReciproal(int32_t x, int x_digits, int *recip_shift) {
|
|||
const int32_t shifted_scaled = one_over_one_plus_x_for_x_in_0_1(shifted_minus_one);
|
||||
return shifted_scaled;
|
||||
}
|
||||
int ConstantPOT(int fractional_bits, int exponent) {
|
||||
int offset = fractional_bits + exponent;
|
||||
return (1 << (uint32_t)offset);
|
||||
}
|
||||
|
||||
int32_t MaskIfNonZero(int32_t a) { return a ? BitNot(0) : 0; }
|
||||
|
||||
int32_t MaskIfZero(int32_t a) { return MaskIfNonZero(!a); }
|
||||
|
||||
int32_t MaskIfLessThan(int32_t a, int32_t b) { return MaskIfNonZero((a < b)); }
|
||||
|
||||
int exp_on_interval_between_negative_one_quarter_and_0_excl(int a) {
|
||||
const int constant_term = 1895147668;
|
||||
const int constant_1_over_3 = 715827883;
|
||||
// We're evaluating a Taylor expansion around -1/8, so we do the change of
|
||||
// variable: x = a + 1/8.
|
||||
// In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
|
||||
int kFractionalBits = FractionsBits(0);
|
||||
int x = a + ConstantPOT(kFractionalBits, -3);
|
||||
int x2 = SaturatingRoundingDoublingHighMul(x, x);
|
||||
int x3 = SaturatingRoundingDoublingHighMul(x2, x);
|
||||
int x4 = SaturatingRoundingDoublingHighMul(x2, x2);
|
||||
int x4_over_4 = SaturatingRoundingMultiplyByPOT(x4, -2);
|
||||
int x4_over_24_plus_x3_over_6_plus_x2_over_2 =
|
||||
SaturatingRoundingMultiplyByPOT((SaturatingRoundingDoublingHighMul((x4_over_4 + x3), constant_1_over_3) + x2), -1);
|
||||
return constant_term +
|
||||
SaturatingRoundingDoublingHighMul(constant_term, (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
|
||||
}
|
||||
|
||||
int exp_on_negative_values(int a, const int tIntegerBits) {
|
||||
int kIntegerBits = tIntegerBits;
|
||||
int kFractionalBits = FractionsBits(tIntegerBits);
|
||||
const int kOneQuarter = ConstantPOT(kFractionalBits, -2);
|
||||
int mask = kOneQuarter - 1;
|
||||
int a_mod_quarter_minus_one_quarter = ((unsigned)(a)&mask) - kOneQuarter;
|
||||
int result =
|
||||
exp_on_interval_between_negative_one_quarter_and_0_excl(Rescale(a_mod_quarter_minus_one_quarter, tIntegerBits, 0));
|
||||
int remainder = a_mod_quarter_minus_one_quarter - a;
|
||||
|
||||
#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \
|
||||
if (kIntegerBits > Exponent) { \
|
||||
const int kMultiplier = FixedPointMultiplier; \
|
||||
int kShiftAmount = kIntegerBits > Exponent ? kFractionalBits + Exponent : 0; \
|
||||
result = SelectUsingMask(MaskIfNonZero(BitAnd(remainder, (1 << (uint32_t)kShiftAmount))), \
|
||||
SaturatingRoundingDoublingHighMul(result, kMultiplier), result); \
|
||||
}
|
||||
GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
|
||||
GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
|
||||
GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
|
||||
GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
|
||||
GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
|
||||
GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
|
||||
GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
|
||||
#undef GEMMLOWP_EXP_BARREL_SHIFTER
|
||||
|
||||
int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
|
||||
if (kIntegerBits > 5) {
|
||||
const int clamp = -(1 << (uint32_t)clampB);
|
||||
result = SelectUsingMask(MaskIfLessThan(a, clamp), 0, result);
|
||||
}
|
||||
|
||||
result = SelectUsingMask(MaskIfZero(a), FixedPoint_One(0, kFractionalBits), result);
|
||||
return result;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) {
|
||||
const int32x4_t shift_vec = vdupq_n_s32(-exponent);
|
||||
|
|
|
@ -60,11 +60,9 @@ int SelectUsingMask(int mask, int bound, int val);
|
|||
|
||||
int32_t MaskNonZero(int32_t a);
|
||||
|
||||
int SaturatingRoundingMultiplyByPOT(int32_t x, int Exponent);
|
||||
|
||||
int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst);
|
||||
|
||||
static int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a);
|
||||
int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a);
|
||||
|
||||
int CountLeadingZeroBits(uint32_t x);
|
||||
|
||||
|
@ -72,6 +70,18 @@ int CountLeadingSignBits(int32_t x);
|
|||
|
||||
int32_t ComputerReciproal(int32_t x, int x_digits, int *recip_shift);
|
||||
|
||||
int exp_on_negative_values(int a, const int tIntegerBits);
|
||||
|
||||
int ConstantPOT(int fractional_bits, int exponent);
|
||||
|
||||
int32_t MaskIfNonZero(int32_t a);
|
||||
|
||||
int32_t MaskIfZero(int32_t a);
|
||||
|
||||
int32_t MaskIfLessThan(int32_t a, int32_t b);
|
||||
|
||||
int exp_on_interval_between_negative_one_quarter_and_0_excl(int a);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -80,9 +80,8 @@ TEST_F(TestSoftmaxInt8, SoftmaxInt8) {
|
|||
auto output_tensor_shape = output0_tensor.shape();
|
||||
kernel->Run();
|
||||
|
||||
std::vector<int8_t> except_result = {-126, -126, -124, -124, -123, -124, -116, -116, 121, 121, 111, 111,
|
||||
-127, -127, -127, -127, -59, -59, -61, -59, 57, 57, 59, 57};
|
||||
|
||||
std::vector<int8_t> except_result = {-126, -126, -124, -124, -123, -124, -116, -116, 122, 122, 112, 112,
|
||||
-127, -127, -127, -127, -59, -59, -61, -59, 58, 58, 59, 58};
|
||||
CompareOutputData(output.data(), except_result.data(), input.size(), 0.000001);
|
||||
|
||||
input0_tensor.SetData(nullptr);
|
||||
|
|
Loading…
Reference in New Issue