forked from mindspore-Ecosystem/mindspore
!9182 [MSLITE] add const node bug
From: @ling_qiao_min Reviewed-by: @zhanghaibo5,@zhang_xue_tong Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
5c5e1520cd
|
@ -21,23 +21,23 @@
|
|||
#include "nnacl/quantization/fixed_point.h"
|
||||
|
||||
void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) {
|
||||
int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_left_shift_);
|
||||
int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_left_shift_);
|
||||
int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_);
|
||||
int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_);
|
||||
int index = 0;
|
||||
|
||||
#ifdef ENABLE_ARM
|
||||
const int8x16_t min_vec = vdupq_n_s8(params->min_);
|
||||
const int8x16_t max_vac = vdupq_n_s8(params->max_);
|
||||
|
||||
const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_zp_);
|
||||
const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_zp_);
|
||||
const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_args_.zp_);
|
||||
const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_args_.zp_);
|
||||
const int16x8_t out_zp_vec = vdupq_n_s16(params->out_zp_);
|
||||
|
||||
const int32x4_t in0_left_vec = vdupq_n_s32(in0_left_shift);
|
||||
const int32x4_t in1_left_vec = vdupq_n_s32(in1_left_shift);
|
||||
|
||||
const int32x4_t in0_right_vec = vdupq_n_s32(-params->in0_right_shift_);
|
||||
const int32x4_t in1_right_vec = vdupq_n_s32(-params->in1_right_shift_);
|
||||
const int32x4_t in0_right_vec = vdupq_n_s32(-params->in0_args_.right_shift_);
|
||||
const int32x4_t in1_right_vec = vdupq_n_s32(-params->in1_args_.right_shift_);
|
||||
|
||||
const int32x4_t out_left_vec = vdupq_n_s32(params->out_left_shift_);
|
||||
const int32x4_t out_right_vec = vdupq_n_s32(-params->out_right_shift_);
|
||||
|
@ -76,14 +76,14 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz
|
|||
in1_4 = vmulq_s32(in1_4, in1_left_vec);
|
||||
|
||||
// Apply the fixed-point part of the multiplier.
|
||||
in0_1 = vqrdmulhq_n_s32(in0_1, params->in0_multiplier_);
|
||||
in0_2 = vqrdmulhq_n_s32(in0_2, params->in0_multiplier_);
|
||||
in0_3 = vqrdmulhq_n_s32(in0_3, params->in0_multiplier_);
|
||||
in0_4 = vqrdmulhq_n_s32(in0_4, params->in0_multiplier_);
|
||||
in1_1 = vqrdmulhq_n_s32(in1_1, params->in1_multiplier_);
|
||||
in1_2 = vqrdmulhq_n_s32(in1_2, params->in1_multiplier_);
|
||||
in1_3 = vqrdmulhq_n_s32(in1_3, params->in1_multiplier_);
|
||||
in1_4 = vqrdmulhq_n_s32(in1_4, params->in1_multiplier_);
|
||||
in0_1 = vqrdmulhq_n_s32(in0_1, params->in0_args_.multiplier_);
|
||||
in0_2 = vqrdmulhq_n_s32(in0_2, params->in0_args_.multiplier_);
|
||||
in0_3 = vqrdmulhq_n_s32(in0_3, params->in0_args_.multiplier_);
|
||||
in0_4 = vqrdmulhq_n_s32(in0_4, params->in0_args_.multiplier_);
|
||||
in1_1 = vqrdmulhq_n_s32(in1_1, params->in1_args_.multiplier_);
|
||||
in1_2 = vqrdmulhq_n_s32(in1_2, params->in1_args_.multiplier_);
|
||||
in1_3 = vqrdmulhq_n_s32(in1_3, params->in1_args_.multiplier_);
|
||||
in1_4 = vqrdmulhq_n_s32(in1_4, params->in1_args_.multiplier_);
|
||||
|
||||
// Apply right shift
|
||||
in0_1 = vqaddq_s32(in0_1, vshrq_n_s32(vandq_s32(in0_1, in0_right_vec), 31));
|
||||
|
@ -149,10 +149,12 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz
|
|||
#endif
|
||||
|
||||
for (; index < size; index++) {
|
||||
const int32_t in0_left = (input0[index] + params->in0_zp_) * in0_left_shift;
|
||||
const int32_t in1_left = (input1[index] + params->in1_zp_) * in1_left_shift;
|
||||
const int32_t in0 = MultiplyByMultiplierAndRightShift(in0_left, params->in0_multiplier_, params->in0_right_shift_);
|
||||
const int32_t in1 = MultiplyByMultiplierAndRightShift(in1_left, params->in1_multiplier_, params->in1_right_shift_);
|
||||
const int32_t in0_left = (input0[index] + params->in0_args_.zp_) * in0_left_shift;
|
||||
const int32_t in1_left = (input1[index] + params->in1_args_.zp_) * in1_left_shift;
|
||||
const int32_t in0 =
|
||||
MultiplyByMultiplierAndRightShift(in0_left, params->in0_args_.multiplier_, params->in0_args_.right_shift_);
|
||||
const int32_t in1 =
|
||||
MultiplyByMultiplierAndRightShift(in1_left, params->in1_args_.multiplier_, params->in1_args_.right_shift_);
|
||||
|
||||
int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_,
|
||||
-params->out_right_shift_);
|
||||
|
@ -162,110 +164,116 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz
|
|||
return;
|
||||
}
|
||||
|
||||
void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params) {
|
||||
int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_left_shift_);
|
||||
int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_left_shift_);
|
||||
void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params,
|
||||
AddQuantQrgs *ptr_args, AddQuantQrgs *ele_args) {
|
||||
int ptr_left_shift = (1 << params->left_shift_) * (1 << ptr_args->left_shift_);
|
||||
int ele_left_shift = (1 << params->left_shift_) * (1 << ele_args->left_shift_);
|
||||
int index = 0;
|
||||
|
||||
#ifdef ENABLE_ARM
|
||||
const int8x16_t in1_src = vdupq_n_s8(element_in);
|
||||
|
||||
/* const value init */
|
||||
const int8x16_t min_vec = vdupq_n_s8(params->min_);
|
||||
const int8x16_t max_vac = vdupq_n_s8(params->max_);
|
||||
|
||||
const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_zp_);
|
||||
const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_zp_);
|
||||
const int16x8_t ptr_zp_vec = vdupq_n_s16(ptr_args->zp_);
|
||||
const int16x8_t ele_zp_vec = vdupq_n_s16(ele_args->zp_);
|
||||
const int16x8_t out_zp_vec = vdupq_n_s16(params->out_zp_);
|
||||
|
||||
const int32x4_t in0_left_vec = vdupq_n_s32(in0_left_shift);
|
||||
const int32x4_t in1_left_vec = vdupq_n_s32(in1_left_shift);
|
||||
const int32x4_t ptr_left_vec = vdupq_n_s32(ptr_left_shift);
|
||||
const int32x4_t ele_left_vec = vdupq_n_s32(ele_left_shift);
|
||||
|
||||
const int32x4_t in0_right_vec = vdupq_n_s32(-params->in0_right_shift_);
|
||||
const int32x4_t in1_right_vec = vdupq_n_s32(-params->in1_right_shift_);
|
||||
const int32x4_t ptr_right_vec = vdupq_n_s32(-ptr_args->right_shift_);
|
||||
const int32x4_t ele_right_vec = vdupq_n_s32(-ptr_args->right_shift_);
|
||||
|
||||
const int32x4_t out_left_vec = vdupq_n_s32(params->out_left_shift_);
|
||||
const int32x4_t out_right_vec = vdupq_n_s32(-params->out_right_shift_);
|
||||
|
||||
/* deal with const node */
|
||||
const int8x16_t ele_src = vdupq_n_s8(element_in);
|
||||
const int16x8_t ele_s16_low = vmovl_s8(vget_low_s8(ele_src));
|
||||
const int16x8_t ele_s16_high = vmovl_s8(vget_high_s8(ele_src));
|
||||
const int16x8_t ele_zp_low = vaddq_s16(ele_s16_low, ele_zp_vec);
|
||||
const int16x8_t ele_zp_high = vaddq_s16(ele_s16_high, ele_zp_vec);
|
||||
int32x4_t ele1 = vmovl_s16(vget_low_s16(ele_zp_low));
|
||||
int32x4_t ele2 = vmovl_s16(vget_high_s16(ele_zp_low));
|
||||
int32x4_t ele3 = vmovl_s16(vget_low_s16(ele_zp_high));
|
||||
int32x4_t ele4 = vmovl_s16(vget_high_s16(ele_zp_high));
|
||||
// Apply left shift
|
||||
ele1 = vmulq_s32(ele1, ele_left_vec);
|
||||
ele2 = vmulq_s32(ele2, ele_left_vec);
|
||||
ele3 = vmulq_s32(ele3, ele_left_vec);
|
||||
ele4 = vmulq_s32(ele4, ele_left_vec);
|
||||
// Apply the fixed-point part of the multiplier.
|
||||
ele1 = vqrdmulhq_n_s32(ele1, ele_args->multiplier_);
|
||||
ele2 = vqrdmulhq_n_s32(ele2, ele_args->multiplier_);
|
||||
ele3 = vqrdmulhq_n_s32(ele3, ele_args->multiplier_);
|
||||
ele4 = vqrdmulhq_n_s32(ele4, ele_args->multiplier_);
|
||||
// Apply right shift
|
||||
ele1 = vqaddq_s32(ele1, vshrq_n_s32(vandq_s32(ele1, ele_right_vec), 31));
|
||||
ele2 = vqaddq_s32(ele2, vshrq_n_s32(vandq_s32(ele2, ele_right_vec), 31));
|
||||
ele3 = vqaddq_s32(ele3, vshrq_n_s32(vandq_s32(ele3, ele_right_vec), 31));
|
||||
ele4 = vqaddq_s32(ele4, vshrq_n_s32(vandq_s32(ele4, ele_right_vec), 31));
|
||||
ele1 = vrshlq_s32(ele1, ele_right_vec);
|
||||
ele2 = vrshlq_s32(ele2, ele_right_vec);
|
||||
ele3 = vrshlq_s32(ele3, ele_right_vec);
|
||||
ele4 = vrshlq_s32(ele4, ele_right_vec);
|
||||
|
||||
for (; index <= size - 16; index += 16) {
|
||||
const int8x16_t in0_src = vld1q_s8(ptr_in + index);
|
||||
const int8x16_t ptr_src = vld1q_s8(ptr_in + index);
|
||||
|
||||
const int16x8_t in0_s16_low = vmovl_s8(vget_low_s8(in0_src));
|
||||
const int16x8_t in0_s16_high = vmovl_s8(vget_high_s8(in0_src));
|
||||
const int16x8_t in1_s16_low = vmovl_s8(vget_low_s8(in1_src));
|
||||
const int16x8_t in1_s16_high = vmovl_s8(vget_high_s8(in1_src));
|
||||
const int16x8_t ptr_s16_low = vmovl_s8(vget_low_s8(ptr_src));
|
||||
const int16x8_t ptr_s16_high = vmovl_s8(vget_high_s8(ptr_src));
|
||||
|
||||
const int16x8_t in0_zp_low = vaddq_s16(in0_s16_low, in0_zp_vec);
|
||||
const int16x8_t in0_zp_high = vaddq_s16(in0_s16_high, in0_zp_vec);
|
||||
const int16x8_t in1_zp_low = vaddq_s16(in1_s16_low, in1_zp_vec);
|
||||
const int16x8_t in1_zp_high = vaddq_s16(in1_s16_high, in1_zp_vec);
|
||||
const int16x8_t ptr_zp_low = vaddq_s16(ptr_s16_low, ptr_zp_vec);
|
||||
const int16x8_t ptr_zp_high = vaddq_s16(ptr_s16_high, ptr_zp_vec);
|
||||
|
||||
int32x4_t in0_1 = vmovl_s16(vget_low_s16(in0_zp_low));
|
||||
int32x4_t in0_2 = vmovl_s16(vget_high_s16(in0_zp_low));
|
||||
int32x4_t in0_3 = vmovl_s16(vget_low_s16(in0_zp_high));
|
||||
int32x4_t in0_4 = vmovl_s16(vget_high_s16(in0_zp_high));
|
||||
int32x4_t in1_1 = vmovl_s16(vget_low_s16(in1_zp_low));
|
||||
int32x4_t in1_2 = vmovl_s16(vget_high_s16(in1_zp_low));
|
||||
int32x4_t in1_3 = vmovl_s16(vget_low_s16(in1_zp_high));
|
||||
int32x4_t in1_4 = vmovl_s16(vget_high_s16(in1_zp_high));
|
||||
int32x4_t ptr1 = vmovl_s16(vget_low_s16(ptr_zp_low));
|
||||
int32x4_t ptr2 = vmovl_s16(vget_high_s16(ptr_zp_low));
|
||||
int32x4_t ptr3 = vmovl_s16(vget_low_s16(ptr_zp_high));
|
||||
int32x4_t ptr4 = vmovl_s16(vget_high_s16(ptr_zp_high));
|
||||
|
||||
// Apply left shift
|
||||
in0_1 = vmulq_s32(in0_1, in0_left_vec);
|
||||
in0_2 = vmulq_s32(in0_2, in0_left_vec);
|
||||
in0_3 = vmulq_s32(in0_3, in0_left_vec);
|
||||
in0_4 = vmulq_s32(in0_4, in0_left_vec);
|
||||
in1_1 = vmulq_s32(in1_1, in1_left_vec);
|
||||
in1_2 = vmulq_s32(in1_2, in1_left_vec);
|
||||
in1_3 = vmulq_s32(in1_3, in1_left_vec);
|
||||
in1_4 = vmulq_s32(in1_4, in1_left_vec);
|
||||
ptr1 = vmulq_s32(ptr1, ptr_left_vec);
|
||||
ptr2 = vmulq_s32(ptr2, ptr_left_vec);
|
||||
ptr3 = vmulq_s32(ptr3, ptr_left_vec);
|
||||
ptr4 = vmulq_s32(ptr4, ptr_left_vec);
|
||||
|
||||
// Apply the fixed-point part of the multiplier.
|
||||
in0_1 = vqrdmulhq_n_s32(in0_1, params->in0_multiplier_);
|
||||
in0_2 = vqrdmulhq_n_s32(in0_2, params->in0_multiplier_);
|
||||
in0_3 = vqrdmulhq_n_s32(in0_3, params->in0_multiplier_);
|
||||
in0_4 = vqrdmulhq_n_s32(in0_4, params->in0_multiplier_);
|
||||
in1_1 = vqrdmulhq_n_s32(in1_1, params->in1_multiplier_);
|
||||
in1_2 = vqrdmulhq_n_s32(in1_2, params->in1_multiplier_);
|
||||
in1_3 = vqrdmulhq_n_s32(in1_3, params->in1_multiplier_);
|
||||
in1_4 = vqrdmulhq_n_s32(in1_4, params->in1_multiplier_);
|
||||
ptr1 = vqrdmulhq_n_s32(ptr1, ptr_args->multiplier_);
|
||||
ptr2 = vqrdmulhq_n_s32(ptr2, ptr_args->multiplier_);
|
||||
ptr3 = vqrdmulhq_n_s32(ptr3, ptr_args->multiplier_);
|
||||
ptr4 = vqrdmulhq_n_s32(ptr4, ptr_args->multiplier_);
|
||||
|
||||
// Apply right shift
|
||||
in0_1 = vqaddq_s32(in0_1, vshrq_n_s32(vandq_s32(in0_1, in0_right_vec), 31));
|
||||
in0_2 = vqaddq_s32(in0_2, vshrq_n_s32(vandq_s32(in0_2, in0_right_vec), 31));
|
||||
in0_3 = vqaddq_s32(in0_3, vshrq_n_s32(vandq_s32(in0_3, in0_right_vec), 31));
|
||||
in0_4 = vqaddq_s32(in0_4, vshrq_n_s32(vandq_s32(in0_4, in0_right_vec), 31));
|
||||
in1_1 = vqaddq_s32(in1_1, vshrq_n_s32(vandq_s32(in1_1, in1_right_vec), 31));
|
||||
in1_2 = vqaddq_s32(in1_2, vshrq_n_s32(vandq_s32(in1_2, in1_right_vec), 31));
|
||||
in1_3 = vqaddq_s32(in1_3, vshrq_n_s32(vandq_s32(in1_3, in1_right_vec), 31));
|
||||
in1_4 = vqaddq_s32(in1_4, vshrq_n_s32(vandq_s32(in1_4, in1_right_vec), 31));
|
||||
ptr1 = vqaddq_s32(ptr1, vshrq_n_s32(vandq_s32(ptr1, ptr_right_vec), 31));
|
||||
ptr2 = vqaddq_s32(ptr2, vshrq_n_s32(vandq_s32(ptr2, ptr_right_vec), 31));
|
||||
ptr3 = vqaddq_s32(ptr3, vshrq_n_s32(vandq_s32(ptr3, ptr_right_vec), 31));
|
||||
ptr4 = vqaddq_s32(ptr4, vshrq_n_s32(vandq_s32(ptr4, ptr_right_vec), 31));
|
||||
|
||||
in0_1 = vrshlq_s32(in0_1, in0_right_vec);
|
||||
in0_2 = vrshlq_s32(in0_2, in0_right_vec);
|
||||
in0_3 = vrshlq_s32(in0_3, in0_right_vec);
|
||||
in0_4 = vrshlq_s32(in0_4, in0_right_vec);
|
||||
in1_1 = vrshlq_s32(in1_1, in1_right_vec);
|
||||
in1_2 = vrshlq_s32(in1_2, in1_right_vec);
|
||||
in1_3 = vrshlq_s32(in1_3, in1_right_vec);
|
||||
in1_4 = vrshlq_s32(in1_4, in1_right_vec);
|
||||
ptr1 = vrshlq_s32(ptr1, ptr_right_vec);
|
||||
ptr2 = vrshlq_s32(ptr2, ptr_right_vec);
|
||||
ptr3 = vrshlq_s32(ptr3, ptr_right_vec);
|
||||
ptr4 = vrshlq_s32(ptr4, ptr_right_vec);
|
||||
|
||||
/* calculate output */
|
||||
int32x4_t out1 = vaddq_s32(in0_1, in1_1);
|
||||
int32x4_t out2 = vaddq_s32(in0_2, in1_2);
|
||||
int32x4_t out3 = vaddq_s32(in0_3, in1_3);
|
||||
int32x4_t out4 = vaddq_s32(in0_4, in1_4);
|
||||
int32x4_t out1 = vaddq_s32(ptr1, ele1);
|
||||
int32x4_t out2 = vaddq_s32(ptr2, ele2);
|
||||
int32x4_t out3 = vaddq_s32(ptr3, ele3);
|
||||
int32x4_t out4 = vaddq_s32(ptr4, ele4);
|
||||
|
||||
// Apply left shift
|
||||
// Apply output left shift
|
||||
out1 = vshlq_s32(out1, out_left_vec);
|
||||
out2 = vshlq_s32(out2, out_left_vec);
|
||||
out3 = vshlq_s32(out3, out_left_vec);
|
||||
out4 = vshlq_s32(out4, out_left_vec);
|
||||
|
||||
// Apply the fixed-point part of the multiplier.
|
||||
// Apply output fixed-point part of the multiplier.
|
||||
out1 = vqrdmulhq_n_s32(out1, params->out_multiplier_);
|
||||
out2 = vqrdmulhq_n_s32(out2, params->out_multiplier_);
|
||||
out3 = vqrdmulhq_n_s32(out3, params->out_multiplier_);
|
||||
out4 = vqrdmulhq_n_s32(out4, params->out_multiplier_);
|
||||
|
||||
// Apply right shift
|
||||
// Apply output right shift
|
||||
out1 = vqaddq_s32(out1, vshrq_n_s32(vandq_s32(out1, out_right_vec), 31));
|
||||
out2 = vqaddq_s32(out2, vshrq_n_s32(vandq_s32(out2, out_right_vec), 31));
|
||||
out3 = vqaddq_s32(out3, vshrq_n_s32(vandq_s32(out3, out_right_vec), 31));
|
||||
|
@ -292,12 +300,12 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i
|
|||
#endif
|
||||
|
||||
for (; index < size; index++) {
|
||||
const int32_t in0_left = (ptr_in[index] + params->in0_zp_) * in0_left_shift;
|
||||
const int32_t in1_left = (element_in + params->in1_zp_) * in1_left_shift;
|
||||
const int32_t in0 = MultiplyByMultiplierAndRightShift(in0_left, params->in0_multiplier_, params->in0_right_shift_);
|
||||
const int32_t in1 = MultiplyByMultiplierAndRightShift(in1_left, params->in1_multiplier_, params->in1_right_shift_);
|
||||
const int32_t ptr_left = (ptr_in[index] + ptr_args->zp_) * ptr_left_shift;
|
||||
const int32_t ele_left = (element_in + ele_args->zp_) * ele_left_shift;
|
||||
const int32_t ptr = MultiplyByMultiplierAndRightShift(ptr_left, ptr_args->multiplier_, ptr_args->right_shift_);
|
||||
const int32_t ele = MultiplyByMultiplierAndRightShift(ele_left, ele_args->multiplier_, ele_args->right_shift_);
|
||||
|
||||
int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_,
|
||||
int32_t out = MultiplyByQuantizedMultiplier(ptr + ele, params->out_multiplier_, params->out_left_shift_,
|
||||
-params->out_right_shift_);
|
||||
out += params->out_zp_;
|
||||
output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_));
|
||||
|
|
|
@ -19,23 +19,22 @@
|
|||
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
typedef struct AddQuantQrgs {
|
||||
int32_t zp_;
|
||||
int32_t left_shift_;
|
||||
int32_t right_shift_;
|
||||
int32_t multiplier_;
|
||||
} AddQuantQrgs;
|
||||
|
||||
typedef struct AddQuantParameter {
|
||||
int left_shift_;
|
||||
int32_t min_;
|
||||
int32_t max_;
|
||||
|
||||
int32_t in0_zp_;
|
||||
int32_t in1_zp_;
|
||||
AddQuantQrgs in0_args_;
|
||||
AddQuantQrgs in1_args_;
|
||||
|
||||
int32_t out_zp_;
|
||||
|
||||
int32_t in0_left_shift_;
|
||||
int32_t in0_right_shift_;
|
||||
int32_t in0_multiplier_;
|
||||
|
||||
int32_t in1_left_shift_;
|
||||
int32_t in1_right_shift_;
|
||||
int32_t in1_multiplier_;
|
||||
|
||||
int32_t out_left_shift_;
|
||||
int32_t out_right_shift_;
|
||||
int32_t out_multiplier_;
|
||||
|
@ -46,7 +45,8 @@ extern "C" {
|
|||
#endif
|
||||
|
||||
void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params);
|
||||
void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params);
|
||||
void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params,
|
||||
AddQuantQrgs *ptr_args, AddQuantQrgs *ele_args);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -35,8 +35,8 @@ int QuantizedAddCPUKernel::Init() {
|
|||
auto *input1 = in_tensors_.at(1);
|
||||
auto *output = out_tensors_.at(0);
|
||||
|
||||
para_.in0_zp_ = input0->quant_params().front().zeroPoint * -1;
|
||||
para_.in1_zp_ = input1->quant_params().front().zeroPoint * -1;
|
||||
para_.in0_args_.zp_ = input0->quant_params().front().zeroPoint * -1;
|
||||
para_.in1_args_.zp_ = input1->quant_params().front().zeroPoint * -1;
|
||||
para_.out_zp_ = output->quant_params().front().zeroPoint;
|
||||
|
||||
const double in0_scale = input0->quant_params().front().scale;
|
||||
|
@ -49,16 +49,16 @@ int QuantizedAddCPUKernel::Init() {
|
|||
const double in1_multiplier = in1_scale / twice_max_input_scale;
|
||||
const double out_multiplier = twice_max_input_scale / ((1 << para_.left_shift_) * out_scale);
|
||||
|
||||
QuantizeMultiplierSmallerThanOne(in0_multiplier, ¶_.in0_multiplier_, ¶_.in0_left_shift_);
|
||||
QuantizeMultiplierSmallerThanOne(in1_multiplier, ¶_.in1_multiplier_, ¶_.in1_left_shift_);
|
||||
QuantizeMultiplierSmallerThanOne(in0_multiplier, ¶_.in0_args_.multiplier_, ¶_.in0_args_.left_shift_);
|
||||
QuantizeMultiplierSmallerThanOne(in1_multiplier, ¶_.in1_args_.multiplier_, ¶_.in1_args_.left_shift_);
|
||||
QuantizeMultiplierSmallerThanOne(out_multiplier, ¶_.out_multiplier_, ¶_.out_left_shift_);
|
||||
|
||||
para_.in0_right_shift_ = -para_.in0_left_shift_ > 0 ? 0 : para_.in0_left_shift_;
|
||||
para_.in1_right_shift_ = -para_.in1_left_shift_ > 0 ? 0 : para_.in1_left_shift_;
|
||||
para_.in0_args_.right_shift_ = -para_.in0_args_.left_shift_ > 0 ? 0 : para_.in0_args_.left_shift_;
|
||||
para_.in1_args_.right_shift_ = -para_.in1_args_.left_shift_ > 0 ? 0 : para_.in1_args_.left_shift_;
|
||||
para_.out_right_shift_ = -para_.out_left_shift_ > 0 ? 0 : para_.out_left_shift_;
|
||||
|
||||
para_.in0_left_shift_ = -para_.in0_left_shift_ > 0 ? -para_.in0_left_shift_ : 0;
|
||||
para_.in1_left_shift_ = -para_.in1_left_shift_ > 0 ? -para_.in1_left_shift_ : 0;
|
||||
para_.in0_args_.left_shift_ = -para_.in0_args_.left_shift_ > 0 ? -para_.in0_args_.left_shift_ : 0;
|
||||
para_.in1_args_.left_shift_ = -para_.in1_args_.left_shift_ > 0 ? -para_.in1_args_.left_shift_ : 0;
|
||||
para_.out_left_shift_ = -para_.out_left_shift_ > 0 ? -para_.out_left_shift_ : 0;
|
||||
|
||||
auto act = arith_para_->activation_type_;
|
||||
|
@ -87,9 +87,24 @@ int QuantizedAddCPUKernel::ReSize() {
|
|||
arith_para_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
|
||||
arith_para_->out_elements_num_ = out_tensors_[0]->ElementsNum();
|
||||
|
||||
memcpy(arith_para_->in_shape0_, input0->shape().data(), input0->shape().size() * sizeof(int));
|
||||
memcpy(arith_para_->in_shape1_, input1->shape().data(), input1->shape().size() * sizeof(int));
|
||||
memcpy(arith_para_->out_shape_, output->shape().data(), output->shape().size() * sizeof(int));
|
||||
for (size_t i = 0; i < in_tensors_[0]->shape().size(); i++) {
|
||||
if (arith_para_->in_shape0_[i] == -1) {
|
||||
memcpy(arith_para_->in_shape0_, input0->shape().data(), input0->shape().size() * sizeof(int));
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < in_tensors_[1]->shape().size(); i++) {
|
||||
if (arith_para_->in_shape1_[i] == -1) {
|
||||
memcpy(arith_para_->in_shape1_, input1->shape().data(), input1->shape().size() * sizeof(int));
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < out_tensors_[0]->shape().size(); i++) {
|
||||
if (arith_para_->out_shape_[i] == -1) {
|
||||
memcpy(arith_para_->out_shape_, output->shape().data(), output->shape().size() * sizeof(int));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (arith_para_->broadcasting_) {
|
||||
size_t break_pos_ = 0;
|
||||
|
@ -128,14 +143,18 @@ void QuantizedAddCPUKernel::BroadcastRun(int task_id) {
|
|||
if (real_out_count <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int8_t *const_in = arith_para_->in_elements_num0_ == arith_para_->out_elements_num_ ? input1_data_ : input0_data_;
|
||||
int8_t *offset_in = arith_para_->in_elements_num0_ == arith_para_->out_elements_num_ ? input0_data_ : input1_data_;
|
||||
offset_in += task_id * stride * in_size_;
|
||||
int8_t *cur_out = output_data_ + task_id * stride * in_size_;
|
||||
|
||||
int8_t *cur_in0, *cur_in1, *cur_out;
|
||||
for (int i = 0; i < real_out_count; i++) {
|
||||
AddInt8(offset_in + i * in_size_, const_in, cur_out + i * in_size_, in_size_, ¶_);
|
||||
if (arith_para_->in_elements_num0_ == arith_para_->out_elements_num_) {
|
||||
cur_in0 = input0_data_ + task_id * stride * in_size_ + i * in_size_;
|
||||
cur_in1 = input1_data_;
|
||||
cur_out = output_data_ + task_id * stride * in_size_ + i * in_size_;
|
||||
} else {
|
||||
cur_in0 = input0_data_;
|
||||
cur_in1 = input1_data_ + task_id * stride * in_size_ + i * in_size_;
|
||||
cur_out = output_data_ + task_id * stride * in_size_ + i * in_size_;
|
||||
}
|
||||
AddInt8(cur_in0, cur_in1, cur_out, in_size_, ¶_);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@ -160,7 +179,9 @@ int QuantizedAddCPUKernel::DoExecute(int task_id) {
|
|||
if (support_opt_add_) {
|
||||
int8_t *ptr_in = arith_para_->in_elements_num0_ == 1 ? cur_in1 : cur_in0;
|
||||
int8_t element_in = arith_para_->in_elements_num0_ == 1 ? input0_data_[0] : input1_data_[0];
|
||||
AddOptInt8(ptr_in, element_in, cur_out, rest_count, ¶_);
|
||||
AddQuantQrgs *ptr_args = arith_para_->in_elements_num0_ == 1 ? ¶_.in1_args_ : ¶_.in0_args_;
|
||||
AddQuantQrgs *ele_args = arith_para_->in_elements_num0_ == 1 ? ¶_.in0_args_ : ¶_.in1_args_;
|
||||
AddOptInt8(ptr_in, element_in, cur_out, rest_count, ¶_, ptr_args, ele_args);
|
||||
} else {
|
||||
AddInt8(cur_in0, cur_in1, cur_out, rest_count, ¶_);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue