!9182 [MSLITE] add const node bug

From: @ling_qiao_min
Reviewed-by: @zhanghaibo5,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong
mindspore-ci-bot 2020-12-02 11:30:57 +08:00 committed by Gitee
commit 5c5e1520cd
3 changed files with 150 additions and 121 deletions


@@ -21,23 +21,23 @@
#include "nnacl/quantization/fixed_point.h"
void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) {
int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_left_shift_);
int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_left_shift_);
int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_);
int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_);
int index = 0;
#ifdef ENABLE_ARM
const int8x16_t min_vec = vdupq_n_s8(params->min_);
const int8x16_t max_vac = vdupq_n_s8(params->max_);
const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_zp_);
const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_zp_);
const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_args_.zp_);
const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_args_.zp_);
const int16x8_t out_zp_vec = vdupq_n_s16(params->out_zp_);
const int32x4_t in0_left_vec = vdupq_n_s32(in0_left_shift);
const int32x4_t in1_left_vec = vdupq_n_s32(in1_left_shift);
const int32x4_t in0_right_vec = vdupq_n_s32(-params->in0_right_shift_);
const int32x4_t in1_right_vec = vdupq_n_s32(-params->in1_right_shift_);
const int32x4_t in0_right_vec = vdupq_n_s32(-params->in0_args_.right_shift_);
const int32x4_t in1_right_vec = vdupq_n_s32(-params->in1_args_.right_shift_);
const int32x4_t out_left_vec = vdupq_n_s32(params->out_left_shift_);
const int32x4_t out_right_vec = vdupq_n_s32(-params->out_right_shift_);
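
Note: the constants above set up a standard rescale-then-add for asymmetrically quantized int8 data. As a point of reference only (not the code path above; the helper name is mine), the float semantics each output element approximates:

#include <math.h>

// Reference semantics in float: real = scale * (q - zeroPoint), add, then re-quantize.
// The kernel stores zp_ as -zeroPoint (see the Init() hunk in the last file), so the
// integer path adds the stored value instead of subtracting the zero point.
static inline int8_t QuantAddFloatRef(int8_t q0, int8_t q1, float in0_scale, int in0_zp, float in1_scale,
                                      int in1_zp, float out_scale, int out_zp, int act_min, int act_max) {
  float real = in0_scale * (q0 - in0_zp) + in1_scale * (q1 - in1_zp);
  int out = (int)lroundf(real / out_scale) + out_zp;
  out = out < act_min ? act_min : (out > act_max ? act_max : out);
  return (int8_t)out;
}
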
@@ -76,14 +76,14 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz
in1_4 = vmulq_s32(in1_4, in1_left_vec);
// Apply the fixed-point part of the multiplier.
in0_1 = vqrdmulhq_n_s32(in0_1, params->in0_multiplier_);
in0_2 = vqrdmulhq_n_s32(in0_2, params->in0_multiplier_);
in0_3 = vqrdmulhq_n_s32(in0_3, params->in0_multiplier_);
in0_4 = vqrdmulhq_n_s32(in0_4, params->in0_multiplier_);
in1_1 = vqrdmulhq_n_s32(in1_1, params->in1_multiplier_);
in1_2 = vqrdmulhq_n_s32(in1_2, params->in1_multiplier_);
in1_3 = vqrdmulhq_n_s32(in1_3, params->in1_multiplier_);
in1_4 = vqrdmulhq_n_s32(in1_4, params->in1_multiplier_);
in0_1 = vqrdmulhq_n_s32(in0_1, params->in0_args_.multiplier_);
in0_2 = vqrdmulhq_n_s32(in0_2, params->in0_args_.multiplier_);
in0_3 = vqrdmulhq_n_s32(in0_3, params->in0_args_.multiplier_);
in0_4 = vqrdmulhq_n_s32(in0_4, params->in0_args_.multiplier_);
in1_1 = vqrdmulhq_n_s32(in1_1, params->in1_args_.multiplier_);
in1_2 = vqrdmulhq_n_s32(in1_2, params->in1_args_.multiplier_);
in1_3 = vqrdmulhq_n_s32(in1_3, params->in1_args_.multiplier_);
in1_4 = vqrdmulhq_n_s32(in1_4, params->in1_args_.multiplier_);
// Apply right shift
in0_1 = vqaddq_s32(in0_1, vshrq_n_s32(vandq_s32(in0_1, in0_right_vec), 31));
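
Note: the vqrdmulhq_n_s32 multiply plus the vqaddq_s32/vrshlq_s32 pair above is the usual gemmlowp-style fixed-point rescale. A scalar sketch of the two steps (helper names are mine, not the ones exported by nnacl/quantization/fixed_point.h):

#include <stdint.h>

// What vqrdmulhq_n_s32 computes per lane: saturating rounding doubling high multiply,
// i.e. roughly round(2 * a * b / 2^32), saturated for the single overflow case.
static inline int32_t SatRoundDoublingHighMul(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;
  int64_t ab = (int64_t)a * (int64_t)b;
  int64_t nudge = ab >= 0 ? (1LL << 30) : (1 - (1LL << 30));
  return (int32_t)((ab + nudge) / (1LL << 31));
}

// What the vandq_s32/vshrq_n_s32 fixup followed by vrshlq_s32 with a negative shift count
// implements per lane: an arithmetic right shift that rounds to nearest.
static inline int32_t RoundingRightShift(int32_t x, int shift) {
  if (shift <= 0) return x;
  const int32_t mask = (int32_t)((1u << shift) - 1);
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> shift) + (remainder > threshold ? 1 : 0);
}
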
@@ -149,10 +149,12 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz
#endif
for (; index < size; index++) {
const int32_t in0_left = (input0[index] + params->in0_zp_) * in0_left_shift;
const int32_t in1_left = (input1[index] + params->in1_zp_) * in1_left_shift;
const int32_t in0 = MultiplyByMultiplierAndRightShift(in0_left, params->in0_multiplier_, params->in0_right_shift_);
const int32_t in1 = MultiplyByMultiplierAndRightShift(in1_left, params->in1_multiplier_, params->in1_right_shift_);
const int32_t in0_left = (input0[index] + params->in0_args_.zp_) * in0_left_shift;
const int32_t in1_left = (input1[index] + params->in1_args_.zp_) * in1_left_shift;
const int32_t in0 =
MultiplyByMultiplierAndRightShift(in0_left, params->in0_args_.multiplier_, params->in0_args_.right_shift_);
const int32_t in1 =
MultiplyByMultiplierAndRightShift(in1_left, params->in1_args_.multiplier_, params->in1_args_.right_shift_);
int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_,
-params->out_right_shift_);
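
Note: the scalar tail mirrors the vector path. A hedged guess at what the two fixed_point.h helpers used here do, expressed with the scalar helpers sketched above (the real nnacl signatures and rounding details are not shown in this diff):

// Assumed behaviour, for illustration only:
static inline int32_t SketchMultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift) {
  return RoundingRightShift(SatRoundDoublingHighMul(value, multiplier), right_shift);
}
// Callers pass -params->out_right_shift_ as the last argument, so a negative value here
// simply means "shift right by its magnitude" after the widening left shift.
static inline int32_t SketchMultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift,
                                                          int32_t shift) {
  return RoundingRightShift(SatRoundDoublingHighMul(value * (1 << left_shift), multiplier), -shift);
}
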
@@ -162,110 +164,116 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz
return;
}
void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params) {
int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_left_shift_);
int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_left_shift_);
void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params,
AddQuantQrgs *ptr_args, AddQuantQrgs *ele_args) {
int ptr_left_shift = (1 << params->left_shift_) * (1 << ptr_args->left_shift_);
int ele_left_shift = (1 << params->left_shift_) * (1 << ele_args->left_shift_);
int index = 0;
#ifdef ENABLE_ARM
const int8x16_t in1_src = vdupq_n_s8(element_in);
/* const value init */
const int8x16_t min_vec = vdupq_n_s8(params->min_);
const int8x16_t max_vac = vdupq_n_s8(params->max_);
const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_zp_);
const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_zp_);
const int16x8_t ptr_zp_vec = vdupq_n_s16(ptr_args->zp_);
const int16x8_t ele_zp_vec = vdupq_n_s16(ele_args->zp_);
const int16x8_t out_zp_vec = vdupq_n_s16(params->out_zp_);
const int32x4_t in0_left_vec = vdupq_n_s32(in0_left_shift);
const int32x4_t in1_left_vec = vdupq_n_s32(in1_left_shift);
const int32x4_t ptr_left_vec = vdupq_n_s32(ptr_left_shift);
const int32x4_t ele_left_vec = vdupq_n_s32(ele_left_shift);
const int32x4_t in0_right_vec = vdupq_n_s32(-params->in0_right_shift_);
const int32x4_t in1_right_vec = vdupq_n_s32(-params->in1_right_shift_);
const int32x4_t ptr_right_vec = vdupq_n_s32(-ptr_args->right_shift_);
const int32x4_t ele_right_vec = vdupq_n_s32(-ele_args->right_shift_);
const int32x4_t out_left_vec = vdupq_n_s32(params->out_left_shift_);
const int32x4_t out_right_vec = vdupq_n_s32(-params->out_right_shift_);
/* deal with const node */
const int8x16_t ele_src = vdupq_n_s8(element_in);
const int16x8_t ele_s16_low = vmovl_s8(vget_low_s8(ele_src));
const int16x8_t ele_s16_high = vmovl_s8(vget_high_s8(ele_src));
const int16x8_t ele_zp_low = vaddq_s16(ele_s16_low, ele_zp_vec);
const int16x8_t ele_zp_high = vaddq_s16(ele_s16_high, ele_zp_vec);
int32x4_t ele1 = vmovl_s16(vget_low_s16(ele_zp_low));
int32x4_t ele2 = vmovl_s16(vget_high_s16(ele_zp_low));
int32x4_t ele3 = vmovl_s16(vget_low_s16(ele_zp_high));
int32x4_t ele4 = vmovl_s16(vget_high_s16(ele_zp_high));
// Apply left shift
ele1 = vmulq_s32(ele1, ele_left_vec);
ele2 = vmulq_s32(ele2, ele_left_vec);
ele3 = vmulq_s32(ele3, ele_left_vec);
ele4 = vmulq_s32(ele4, ele_left_vec);
// Apply the fixed-point part of the multiplier.
ele1 = vqrdmulhq_n_s32(ele1, ele_args->multiplier_);
ele2 = vqrdmulhq_n_s32(ele2, ele_args->multiplier_);
ele3 = vqrdmulhq_n_s32(ele3, ele_args->multiplier_);
ele4 = vqrdmulhq_n_s32(ele4, ele_args->multiplier_);
// Apply right shift
ele1 = vqaddq_s32(ele1, vshrq_n_s32(vandq_s32(ele1, ele_right_vec), 31));
ele2 = vqaddq_s32(ele2, vshrq_n_s32(vandq_s32(ele2, ele_right_vec), 31));
ele3 = vqaddq_s32(ele3, vshrq_n_s32(vandq_s32(ele3, ele_right_vec), 31));
ele4 = vqaddq_s32(ele4, vshrq_n_s32(vandq_s32(ele4, ele_right_vec), 31));
ele1 = vrshlq_s32(ele1, ele_right_vec);
ele2 = vrshlq_s32(ele2, ele_right_vec);
ele3 = vrshlq_s32(ele3, ele_right_vec);
ele4 = vrshlq_s32(ele4, ele_right_vec);
for (; index <= size - 16; index += 16) {
const int8x16_t in0_src = vld1q_s8(ptr_in + index);
const int8x16_t ptr_src = vld1q_s8(ptr_in + index);
const int16x8_t in0_s16_low = vmovl_s8(vget_low_s8(in0_src));
const int16x8_t in0_s16_high = vmovl_s8(vget_high_s8(in0_src));
const int16x8_t in1_s16_low = vmovl_s8(vget_low_s8(in1_src));
const int16x8_t in1_s16_high = vmovl_s8(vget_high_s8(in1_src));
const int16x8_t ptr_s16_low = vmovl_s8(vget_low_s8(ptr_src));
const int16x8_t ptr_s16_high = vmovl_s8(vget_high_s8(ptr_src));
const int16x8_t in0_zp_low = vaddq_s16(in0_s16_low, in0_zp_vec);
const int16x8_t in0_zp_high = vaddq_s16(in0_s16_high, in0_zp_vec);
const int16x8_t in1_zp_low = vaddq_s16(in1_s16_low, in1_zp_vec);
const int16x8_t in1_zp_high = vaddq_s16(in1_s16_high, in1_zp_vec);
const int16x8_t ptr_zp_low = vaddq_s16(ptr_s16_low, ptr_zp_vec);
const int16x8_t ptr_zp_high = vaddq_s16(ptr_s16_high, ptr_zp_vec);
int32x4_t in0_1 = vmovl_s16(vget_low_s16(in0_zp_low));
int32x4_t in0_2 = vmovl_s16(vget_high_s16(in0_zp_low));
int32x4_t in0_3 = vmovl_s16(vget_low_s16(in0_zp_high));
int32x4_t in0_4 = vmovl_s16(vget_high_s16(in0_zp_high));
int32x4_t in1_1 = vmovl_s16(vget_low_s16(in1_zp_low));
int32x4_t in1_2 = vmovl_s16(vget_high_s16(in1_zp_low));
int32x4_t in1_3 = vmovl_s16(vget_low_s16(in1_zp_high));
int32x4_t in1_4 = vmovl_s16(vget_high_s16(in1_zp_high));
int32x4_t ptr1 = vmovl_s16(vget_low_s16(ptr_zp_low));
int32x4_t ptr2 = vmovl_s16(vget_high_s16(ptr_zp_low));
int32x4_t ptr3 = vmovl_s16(vget_low_s16(ptr_zp_high));
int32x4_t ptr4 = vmovl_s16(vget_high_s16(ptr_zp_high));
// Apply left shift
in0_1 = vmulq_s32(in0_1, in0_left_vec);
in0_2 = vmulq_s32(in0_2, in0_left_vec);
in0_3 = vmulq_s32(in0_3, in0_left_vec);
in0_4 = vmulq_s32(in0_4, in0_left_vec);
in1_1 = vmulq_s32(in1_1, in1_left_vec);
in1_2 = vmulq_s32(in1_2, in1_left_vec);
in1_3 = vmulq_s32(in1_3, in1_left_vec);
in1_4 = vmulq_s32(in1_4, in1_left_vec);
ptr1 = vmulq_s32(ptr1, ptr_left_vec);
ptr2 = vmulq_s32(ptr2, ptr_left_vec);
ptr3 = vmulq_s32(ptr3, ptr_left_vec);
ptr4 = vmulq_s32(ptr4, ptr_left_vec);
// Apply the fixed-point part of the multiplier.
in0_1 = vqrdmulhq_n_s32(in0_1, params->in0_multiplier_);
in0_2 = vqrdmulhq_n_s32(in0_2, params->in0_multiplier_);
in0_3 = vqrdmulhq_n_s32(in0_3, params->in0_multiplier_);
in0_4 = vqrdmulhq_n_s32(in0_4, params->in0_multiplier_);
in1_1 = vqrdmulhq_n_s32(in1_1, params->in1_multiplier_);
in1_2 = vqrdmulhq_n_s32(in1_2, params->in1_multiplier_);
in1_3 = vqrdmulhq_n_s32(in1_3, params->in1_multiplier_);
in1_4 = vqrdmulhq_n_s32(in1_4, params->in1_multiplier_);
ptr1 = vqrdmulhq_n_s32(ptr1, ptr_args->multiplier_);
ptr2 = vqrdmulhq_n_s32(ptr2, ptr_args->multiplier_);
ptr3 = vqrdmulhq_n_s32(ptr3, ptr_args->multiplier_);
ptr4 = vqrdmulhq_n_s32(ptr4, ptr_args->multiplier_);
// Apply right shift
in0_1 = vqaddq_s32(in0_1, vshrq_n_s32(vandq_s32(in0_1, in0_right_vec), 31));
in0_2 = vqaddq_s32(in0_2, vshrq_n_s32(vandq_s32(in0_2, in0_right_vec), 31));
in0_3 = vqaddq_s32(in0_3, vshrq_n_s32(vandq_s32(in0_3, in0_right_vec), 31));
in0_4 = vqaddq_s32(in0_4, vshrq_n_s32(vandq_s32(in0_4, in0_right_vec), 31));
in1_1 = vqaddq_s32(in1_1, vshrq_n_s32(vandq_s32(in1_1, in1_right_vec), 31));
in1_2 = vqaddq_s32(in1_2, vshrq_n_s32(vandq_s32(in1_2, in1_right_vec), 31));
in1_3 = vqaddq_s32(in1_3, vshrq_n_s32(vandq_s32(in1_3, in1_right_vec), 31));
in1_4 = vqaddq_s32(in1_4, vshrq_n_s32(vandq_s32(in1_4, in1_right_vec), 31));
ptr1 = vqaddq_s32(ptr1, vshrq_n_s32(vandq_s32(ptr1, ptr_right_vec), 31));
ptr2 = vqaddq_s32(ptr2, vshrq_n_s32(vandq_s32(ptr2, ptr_right_vec), 31));
ptr3 = vqaddq_s32(ptr3, vshrq_n_s32(vandq_s32(ptr3, ptr_right_vec), 31));
ptr4 = vqaddq_s32(ptr4, vshrq_n_s32(vandq_s32(ptr4, ptr_right_vec), 31));
in0_1 = vrshlq_s32(in0_1, in0_right_vec);
in0_2 = vrshlq_s32(in0_2, in0_right_vec);
in0_3 = vrshlq_s32(in0_3, in0_right_vec);
in0_4 = vrshlq_s32(in0_4, in0_right_vec);
in1_1 = vrshlq_s32(in1_1, in1_right_vec);
in1_2 = vrshlq_s32(in1_2, in1_right_vec);
in1_3 = vrshlq_s32(in1_3, in1_right_vec);
in1_4 = vrshlq_s32(in1_4, in1_right_vec);
ptr1 = vrshlq_s32(ptr1, ptr_right_vec);
ptr2 = vrshlq_s32(ptr2, ptr_right_vec);
ptr3 = vrshlq_s32(ptr3, ptr_right_vec);
ptr4 = vrshlq_s32(ptr4, ptr_right_vec);
/* calculate output */
int32x4_t out1 = vaddq_s32(in0_1, in1_1);
int32x4_t out2 = vaddq_s32(in0_2, in1_2);
int32x4_t out3 = vaddq_s32(in0_3, in1_3);
int32x4_t out4 = vaddq_s32(in0_4, in1_4);
int32x4_t out1 = vaddq_s32(ptr1, ele1);
int32x4_t out2 = vaddq_s32(ptr2, ele2);
int32x4_t out3 = vaddq_s32(ptr3, ele3);
int32x4_t out4 = vaddq_s32(ptr4, ele4);
// Apply left shift
// Apply output left shift
out1 = vshlq_s32(out1, out_left_vec);
out2 = vshlq_s32(out2, out_left_vec);
out3 = vshlq_s32(out3, out_left_vec);
out4 = vshlq_s32(out4, out_left_vec);
// Apply the fixed-point part of the multiplier.
// Apply output fixed-point part of the multiplier.
out1 = vqrdmulhq_n_s32(out1, params->out_multiplier_);
out2 = vqrdmulhq_n_s32(out2, params->out_multiplier_);
out3 = vqrdmulhq_n_s32(out3, params->out_multiplier_);
out4 = vqrdmulhq_n_s32(out4, params->out_multiplier_);
// Apply right shift
// Apply output right shift
out1 = vqaddq_s32(out1, vshrq_n_s32(vandq_s32(out1, out_right_vec), 31));
out2 = vqaddq_s32(out2, vshrq_n_s32(vandq_s32(out2, out_right_vec), 31));
out3 = vqaddq_s32(out3, vshrq_n_s32(vandq_s32(out3, out_right_vec), 31));
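
Note: the point of the rewritten AddOptInt8 is visible above: one operand is a const node holding a single int8 value, so its zero-point add, left shift, fixed-point multiply and rounding shift (ele1..ele4) are computed once before the loop instead of for every element. The same idea in scalar form, assembled from the tail loop in the next hunk (a sketch, not code from the patch):

const int32_t ele_left = (element_in + ele_args->zp_) * ele_left_shift;
const int32_t ele = MultiplyByMultiplierAndRightShift(ele_left, ele_args->multiplier_, ele_args->right_shift_);
for (; index < size; index++) {
  const int32_t ptr_left = (ptr_in[index] + ptr_args->zp_) * ptr_left_shift;
  const int32_t ptr = MultiplyByMultiplierAndRightShift(ptr_left, ptr_args->multiplier_, ptr_args->right_shift_);
  int32_t out = MultiplyByQuantizedMultiplier(ptr + ele, params->out_multiplier_, params->out_left_shift_,
                                              -params->out_right_shift_);
  out += params->out_zp_;
  output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_));
}
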
@@ -292,12 +300,12 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i
#endif
for (; index < size; index++) {
const int32_t in0_left = (ptr_in[index] + params->in0_zp_) * in0_left_shift;
const int32_t in1_left = (element_in + params->in1_zp_) * in1_left_shift;
const int32_t in0 = MultiplyByMultiplierAndRightShift(in0_left, params->in0_multiplier_, params->in0_right_shift_);
const int32_t in1 = MultiplyByMultiplierAndRightShift(in1_left, params->in1_multiplier_, params->in1_right_shift_);
const int32_t ptr_left = (ptr_in[index] + ptr_args->zp_) * ptr_left_shift;
const int32_t ele_left = (element_in + ele_args->zp_) * ele_left_shift;
const int32_t ptr = MultiplyByMultiplierAndRightShift(ptr_left, ptr_args->multiplier_, ptr_args->right_shift_);
const int32_t ele = MultiplyByMultiplierAndRightShift(ele_left, ele_args->multiplier_, ele_args->right_shift_);
int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_,
int32_t out = MultiplyByQuantizedMultiplier(ptr + ele, params->out_multiplier_, params->out_left_shift_,
-params->out_right_shift_);
out += params->out_zp_;
output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_));


@@ -19,23 +19,22 @@
#include "nnacl/op_base.h"
typedef struct AddQuantQrgs {
int32_t zp_;
int32_t left_shift_;
int32_t right_shift_;
int32_t multiplier_;
} AddQuantQrgs;
typedef struct AddQuantParameter {
int left_shift_;
int32_t min_;
int32_t max_;
int32_t in0_zp_;
int32_t in1_zp_;
AddQuantQrgs in0_args_;
AddQuantQrgs in1_args_;
int32_t out_zp_;
int32_t in0_left_shift_;
int32_t in0_right_shift_;
int32_t in0_multiplier_;
int32_t in1_left_shift_;
int32_t in1_right_shift_;
int32_t in1_multiplier_;
int32_t out_left_shift_;
int32_t out_right_shift_;
int32_t out_multiplier_;
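
Note: reassembled from this hunk, the layout after the patch groups each input's per-operand quantization arguments into one AddQuantQrgs while the shared and output-side fields stay flat (members outside the displayed hunk are omitted):

typedef struct AddQuantQrgs {
  int32_t zp_;
  int32_t left_shift_;
  int32_t right_shift_;
  int32_t multiplier_;
} AddQuantQrgs;

typedef struct AddQuantParameter {
  int left_shift_;         // shared pre-shift applied to both inputs
  int32_t min_;            // activation clamp range
  int32_t max_;
  AddQuantQrgs in0_args_;  // replaces in0_zp_ / in0_left_shift_ / in0_right_shift_ / in0_multiplier_
  AddQuantQrgs in1_args_;  // replaces the corresponding in1_* fields
  int32_t out_zp_;
  int32_t out_left_shift_;
  int32_t out_right_shift_;
  int32_t out_multiplier_;
  /* ... members not visible in this hunk omitted ... */
} AddQuantParameter;
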
@@ -46,7 +45,8 @@ extern "C" {
#endif
void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params);
void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params);
void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params,
AddQuantQrgs *ptr_args, AddQuantQrgs *ele_args);
#ifdef __cplusplus
}


@@ -35,8 +35,8 @@ int QuantizedAddCPUKernel::Init() {
auto *input1 = in_tensors_.at(1);
auto *output = out_tensors_.at(0);
para_.in0_zp_ = input0->quant_params().front().zeroPoint * -1;
para_.in1_zp_ = input1->quant_params().front().zeroPoint * -1;
para_.in0_args_.zp_ = input0->quant_params().front().zeroPoint * -1;
para_.in1_args_.zp_ = input1->quant_params().front().zeroPoint * -1;
para_.out_zp_ = output->quant_params().front().zeroPoint;
const double in0_scale = input0->quant_params().front().scale;
@@ -49,16 +49,16 @@ int QuantizedAddCPUKernel::Init() {
const double in1_multiplier = in1_scale / twice_max_input_scale;
const double out_multiplier = twice_max_input_scale / ((1 << para_.left_shift_) * out_scale);
QuantizeMultiplierSmallerThanOne(in0_multiplier, &para_.in0_multiplier_, &para_.in0_left_shift_);
QuantizeMultiplierSmallerThanOne(in1_multiplier, &para_.in1_multiplier_, &para_.in1_left_shift_);
QuantizeMultiplierSmallerThanOne(in0_multiplier, &para_.in0_args_.multiplier_, &para_.in0_args_.left_shift_);
QuantizeMultiplierSmallerThanOne(in1_multiplier, &para_.in1_args_.multiplier_, &para_.in1_args_.left_shift_);
QuantizeMultiplierSmallerThanOne(out_multiplier, &para_.out_multiplier_, &para_.out_left_shift_);
para_.in0_right_shift_ = -para_.in0_left_shift_ > 0 ? 0 : para_.in0_left_shift_;
para_.in1_right_shift_ = -para_.in1_left_shift_ > 0 ? 0 : para_.in1_left_shift_;
para_.in0_args_.right_shift_ = -para_.in0_args_.left_shift_ > 0 ? 0 : para_.in0_args_.left_shift_;
para_.in1_args_.right_shift_ = -para_.in1_args_.left_shift_ > 0 ? 0 : para_.in1_args_.left_shift_;
para_.out_right_shift_ = -para_.out_left_shift_ > 0 ? 0 : para_.out_left_shift_;
para_.in0_left_shift_ = -para_.in0_left_shift_ > 0 ? -para_.in0_left_shift_ : 0;
para_.in1_left_shift_ = -para_.in1_left_shift_ > 0 ? -para_.in1_left_shift_ : 0;
para_.in0_args_.left_shift_ = -para_.in0_args_.left_shift_ > 0 ? -para_.in0_args_.left_shift_ : 0;
para_.in1_args_.left_shift_ = -para_.in1_args_.left_shift_ > 0 ? -para_.in1_args_.left_shift_ : 0;
para_.out_left_shift_ = -para_.out_left_shift_ > 0 ? -para_.out_left_shift_ : 0;
auto act = arith_para_->activation_type_;
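
Note: Init() converts the real-valued ratios above into fixed-point form. A hedged sketch of the usual decomposition (the exact rounding and sign convention of QuantizeMultiplierSmallerThanOne are not visible in this diff; the helper below is an illustration): a multiplier in (0, 1) becomes a Q31 integer plus a power-of-two shift, and the two ternaries that follow split that shift so left_shift_ and right_shift_ are both non-negative and at most one of them is non-zero.

#include <math.h>
#include <stdint.h>

// Sketch only: real_multiplier ~= quantized_multiplier * 2^(-right_shift), quantized_multiplier in Q31.
static void DecomposeMultiplier(double real_multiplier, int32_t *quantized_multiplier, int *right_shift) {
  int exponent = 0;
  const double q = frexp(real_multiplier, &exponent);  // real_multiplier = q * 2^exponent, q in [0.5, 1)
  int64_t q_fixed = (int64_t)llround(q * (double)(1LL << 31));
  if (q_fixed == (1LL << 31)) {  // q rounded up to exactly 1.0
    q_fixed /= 2;
    ++exponent;
  }
  *quantized_multiplier = (int32_t)q_fixed;
  *right_shift = -exponent;  // positive means shift right by this amount
}

Since in0_multiplier and in1_multiplier are each input scale divided by twice_max_input_scale (defined outside the displayed hunk, presumably twice the larger input scale), they are at most 0.5, so the per-input decomposition normally yields a genuine right shift.
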
@@ -87,9 +87,24 @@ int QuantizedAddCPUKernel::ReSize() {
arith_para_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
arith_para_->out_elements_num_ = out_tensors_[0]->ElementsNum();
memcpy(arith_para_->in_shape0_, input0->shape().data(), input0->shape().size() * sizeof(int));
memcpy(arith_para_->in_shape1_, input1->shape().data(), input1->shape().size() * sizeof(int));
memcpy(arith_para_->out_shape_, output->shape().data(), output->shape().size() * sizeof(int));
for (size_t i = 0; i < in_tensors_[0]->shape().size(); i++) {
if (arith_para_->in_shape0_[i] == -1) {
memcpy(arith_para_->in_shape0_, input0->shape().data(), input0->shape().size() * sizeof(int));
break;
}
}
for (size_t i = 0; i < in_tensors_[1]->shape().size(); i++) {
if (arith_para_->in_shape1_[i] == -1) {
memcpy(arith_para_->in_shape1_, input1->shape().data(), input1->shape().size() * sizeof(int));
break;
}
}
for (size_t i = 0; i < out_tensors_[0]->shape().size(); i++) {
if (arith_para_->out_shape_[i] == -1) {
memcpy(arith_para_->out_shape_, output->shape().data(), output->shape().size() * sizeof(int));
break;
}
}
if (arith_para_->broadcasting_) {
size_t break_pos_ = 0;
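
Note: the three new loops in ReSize() share one pattern: re-copy a cached shape only if it still holds a -1 placeholder, i.e. a dimension that was not known when the parameter was first populated, presumably so shapes that are already fixed (for example for const inputs) are not overwritten. The same logic as a small helper (an illustration, not part of the patch):

#include <string.h>

// Refresh dst_shape from src_shape only if some dimension is still the -1 placeholder.
static void RefreshShapeIfUnknown(int *dst_shape, const int *src_shape, size_t dims) {
  for (size_t i = 0; i < dims; i++) {
    if (dst_shape[i] == -1) {
      memcpy(dst_shape, src_shape, dims * sizeof(int));
      break;
    }
  }
}
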
@@ -128,14 +143,18 @@ void QuantizedAddCPUKernel::BroadcastRun(int task_id) {
if (real_out_count <= 0) {
return;
}
int8_t *const_in = arith_para_->in_elements_num0_ == arith_para_->out_elements_num_ ? input1_data_ : input0_data_;
int8_t *offset_in = arith_para_->in_elements_num0_ == arith_para_->out_elements_num_ ? input0_data_ : input1_data_;
offset_in += task_id * stride * in_size_;
int8_t *cur_out = output_data_ + task_id * stride * in_size_;
int8_t *cur_in0, *cur_in1, *cur_out;
for (int i = 0; i < real_out_count; i++) {
AddInt8(offset_in + i * in_size_, const_in, cur_out + i * in_size_, in_size_, &para_);
if (arith_para_->in_elements_num0_ == arith_para_->out_elements_num_) {
cur_in0 = input0_data_ + task_id * stride * in_size_ + i * in_size_;
cur_in1 = input1_data_;
cur_out = output_data_ + task_id * stride * in_size_ + i * in_size_;
} else {
cur_in0 = input0_data_;
cur_in1 = input1_data_ + task_id * stride * in_size_ + i * in_size_;
cur_out = output_data_ + task_id * stride * in_size_ + i * in_size_;
}
AddInt8(cur_in0, cur_in1, cur_out, in_size_, &para_);
}
return;
}
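
Note: a hedged reading of why the per-iteration selection above matters: the removed loop always passed the strided tensor as AddInt8's first operand and the broadcast (const) tensor as the second, so whenever input0 was the broadcast one the operands arrived swapped relative to in0_args_ and in1_args_. The rewrite keeps each input in its own operand slot:

/* old (removed): operand order depended on which input is broadcast */
AddInt8(offset_in + i * in_size_, const_in, cur_out + i * in_size_, in_size_, &para_);
/* new: input0 always feeds the first operand, input1 the second */
AddInt8(cur_in0, cur_in1, cur_out, in_size_, &para_);
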
@@ -160,7 +179,9 @@ int QuantizedAddCPUKernel::DoExecute(int task_id) {
if (support_opt_add_) {
int8_t *ptr_in = arith_para_->in_elements_num0_ == 1 ? cur_in1 : cur_in0;
int8_t element_in = arith_para_->in_elements_num0_ == 1 ? input0_data_[0] : input1_data_[0];
AddOptInt8(ptr_in, element_in, cur_out, rest_count, &para_);
AddQuantQrgs *ptr_args = arith_para_->in_elements_num0_ == 1 ? &para_.in1_args_ : &para_.in0_args_;
AddQuantQrgs *ele_args = arith_para_->in_elements_num0_ == 1 ? &para_.in0_args_ : &para_.in1_args_;
AddOptInt8(ptr_in, element_in, cur_out, rest_count, &para_, ptr_args, ele_args);
} else {
AddInt8(cur_in0, cur_in1, cur_out, rest_count, &para_);
}