!8148 [MS][LITE][CPU]optimize int8 scale op

Merge pull request !8148 from fuzhiye/tmp
This commit is contained in:
mindspore-ci-bot 2020-11-03 15:53:01 +08:00 committed by Gitee
commit c47d0d68fe
6 changed files with 331 additions and 165 deletions

View File

@ -53,6 +53,8 @@ void ComputeStrides(const int *shape, int *strides, const int ndim);
void CalcMultiplesAndStrides(ArithmeticParameter *param);
void TileOneDimensionUint8(uint8_t *inData, uint8_t *outData, int dim, size_t ndim, int *inShape, int *inStrides,
int *outStrides, int *multiple);
void TileDimensions(float *data0, float *data1, float *tile_data0, float *tile_data1, ArithmeticParameter *param);
void TileDimensionsUint8(uint8_t *data0, uint8_t *data1, uint8_t *tile_data0, uint8_t *tile_data1,
ArithmeticParameter *param);

View File

@ -17,78 +17,148 @@
#include "nnacl/int8/scale_int8.h"
#include "nnacl/quantization/fixed_point.h"
void ScaleInnerInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int outer_start, int outer_end,
int axis_size, int inner_size, const ScaleParameter *scale_param, int max, int min) {
for (int out = outer_start; out < outer_end; out++) {
int out_offset = out * axis_size * inner_size;
for (int i = 0; i < axis_size; i++) {
int axis_offset = out_offset + i * inner_size;
int in_index = 0;
#ifdef ENABLE_NEON
int16x4_t ClacSumHalfWordMul2(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec,
int32x4_t output_multiplier_vec, const ScaleParameter *scale_param) {
int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1);
int32x4_t raw_sum = RoundingDivideByPOTInt32x4(
SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec),
scale_param->scale_mul_arg_.right_shift_);
raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(scale_param->output_zp_));
raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_min_));
raw_sum = vminq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_max_));
return vqmovn_s32(raw_sum);
}
for (; in_index < inner_size; in_index++) {
int in_offset = axis_offset + in_index;
int tmp_input_scale = (in_data[in_offset] - scale_param->input_zp_) * (scale[i] - scale_param->scale_zp_);
int input_mul_scale =
RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
tmp_input_scale * (1 << (unsigned int)scale_param->scale_mul_arg_.left_shift_),
int16x4_t ClacSumHalfWordMul3(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t scaled_input2,
const ScaleParameter *scale_param) {
int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_);
int32x4_t output_multiplier_vec2 = vdupq_n_s32(scale_param->offset_mul_arg_.multiplier_);
int32x4_t left_shift_out_vec = vdupq_n_s32(1 << scale_param->scale_mul_arg_.left_shift_);
int32x4_t left_shift_out_vec2 = vdupq_n_s32(1 << scale_param->offset_mul_arg_.left_shift_);
int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1);
int32x4_t raw_sum = RoundingDivideByPOTInt32x4(
SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec),
scale_param->scale_mul_arg_.right_shift_);
int32x4_t raw_sum2 = RoundingDivideByPOTInt32x4(
SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(scaled_input2, left_shift_out_vec2), output_multiplier_vec2),
scale_param->offset_mul_arg_.right_shift_);
raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(scale_param->output_zp_));
raw_sum = vaddq_s32(raw_sum, raw_sum2);
raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_min_));
raw_sum = vminq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_max_));
return vqmovn_s32(raw_sum);
}
#endif
void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleParameter *scale_param,
int real_dst_count) {
int index = 0;
#ifdef ENABLE_NEON
int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_);
int32x4_t left_shift_out_vec = vdupq_n_s32(1 << scale_param->scale_mul_arg_.left_shift_);
for (; index <= real_dst_count - 8; index += 8) {
int8x8_t input_s8 = vld1_s8(in_data + index);
int16x8_t input_s16 = vmovl_s8(input_s8);
int16x8_t input0_val = vaddq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_));
int8x8_t input1_s8 = vld1_s8(scale + index);
int16x8_t input1_s16 = vmovl_s8(input1_s8);
int16x8_t input1_val = vaddq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_));
int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val));
int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val));
int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val));
int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val));
int16x4_t sum_low =
ClacSumHalfWordMul2(input0_low, input1_low, left_shift_out_vec, output_multiplier_vec, scale_param);
int16x4_t sum_high =
ClacSumHalfWordMul2(input0_high, input1_high, left_shift_out_vec, output_multiplier_vec, scale_param);
int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
vst1_s8(out_data, res_u8_n0);
out_data += 8;
}
#endif
for (; index < real_dst_count; ++index) {
const int32_t input0_val = scale_param->input_zp_ + in_data[index];
const int32_t input1_val = scale_param->scale_zp_ + scale[index];
int32_t mul_result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << scale_param->scale_mul_arg_.left_shift_),
scale_param->scale_mul_arg_.multiplier_),
scale_param->scale_mul_arg_.right_shift_);
int tmp = input_mul_scale + scale_param->output_zp_;
tmp = tmp > max ? max : tmp;
tmp = tmp < min ? min : tmp;
out_data[in_offset] = tmp;
}
mul_result += scale_param->output_zp_;
if (mul_result > scale_param->output_activation_max_) {
out_data[index] = scale_param->output_activation_max_;
} else if (mul_result < scale_param->output_activation_min_) {
out_data[index] = scale_param->output_activation_min_;
} else {
out_data[index] = (int8_t)mul_result;
}
}
return;
}
void ScaleInnerWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
int outer_start, int outer_end, int axis_size, int inner_size,
const ScaleParameter *scale_param, int max, int min) {
for (int out = outer_start; out < outer_end; out++) {
int out_offset = out * axis_size * inner_size;
for (int i = 0; i < axis_size; i++) {
int axis_offset = out_offset + i * inner_size;
int in_index = 0;
void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
const ScaleParameter *scale_param, int real_dst_count) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= real_dst_count - 8; index += 8) {
int8x8_t input_s8 = vld1_s8(in_data + index);
int16x8_t input_s16 = vmovl_s8(input_s8);
int16x8_t input0_val = vaddq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_));
for (; in_index < inner_size; in_index++) {
int in_offset = axis_offset + in_index;
int tmp_input_scale = (in_data[in_offset] - scale_param->input_zp_) * (scale[i] - scale_param->scale_zp_);
int input_mul_scale =
RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
tmp_input_scale * (1 << (unsigned int)scale_param->scale_mul_arg_.left_shift_),
int8x8_t input1_s8 = vld1_s8(scale + index);
int16x8_t input1_s16 = vmovl_s8(input1_s8);
int16x8_t input1_val = vaddq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_));
int8x8_t input2_s8 = vld1_s8(offset + index);
int16x8_t input2_s16 = vmovl_s8(input2_s8);
int16x8_t input2_val = vaddq_s16(input2_s16, vdupq_n_s16(scale_param->offset_zp_));
int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val));
int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val));
int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val));
int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val));
int32x4_t input2_low = vmovl_s16(vget_low_s16(input2_val));
int32x4_t input2_high = vmovl_s16(vget_high_s16(input2_val));
int16x4_t sum_low = ClacSumHalfWordMul3(input0_low, input1_low, input2_low, scale_param);
int16x4_t sum_high = ClacSumHalfWordMul3(input0_high, input1_high, input2_high, scale_param);
int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
vst1_s8(out_data, res_u8_n0);
out_data += 8;
}
#endif
for (; index < real_dst_count; ++index) {
const int32_t input0_val = in_data[index] - scale_param->input_zp_;
const int32_t input1_val = scale[index] - scale_param->scale_zp_;
int32_t mul_result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << scale_param->scale_mul_arg_.left_shift_),
scale_param->scale_mul_arg_.multiplier_),
scale_param->scale_mul_arg_.right_shift_);
int tmp_bias = offset[i] - scale_param->offset_zp_;
int tmp_bias = offset[index] - scale_param->offset_zp_;
int bias = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(tmp_bias * (1 << (unsigned int)scale_param->offset_mul_arg_.left_shift_),
scale_param->offset_mul_arg_.multiplier_),
scale_param->offset_mul_arg_.right_shift_);
int tmp = input_mul_scale + bias + scale_param->output_zp_;
tmp = tmp > max ? max : tmp;
tmp = tmp < min ? min : tmp;
out_data[in_offset] = tmp;
}
}
}
}
void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int task_id,
const ScaleParameter *scale_param, int max, int min) {
int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
int outer_start = task_id * outer_step;
int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
mul_result += bias + scale_param->output_zp_;
ScaleInnerInt8(in_data, out_data, scale, outer_start, outer_end, scale_param->axis_size_, scale_param->inner_size_,
scale_param, max, min);
if (mul_result > scale_param->output_activation_max_) {
out_data[index] = scale_param->output_activation_max_;
} else if (mul_result < scale_param->output_activation_min_) {
out_data[index] = scale_param->output_activation_min_;
} else {
out_data[index] = (int8_t)mul_result;
}
void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
int task_id, const ScaleParameter *scale_param, int max, int min) {
int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
int outer_start = task_id * outer_step;
int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
ScaleInnerWithBiasInt8(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_,
scale_param->inner_size_, scale_param, max, min);
}
return;
}

View File

@ -22,10 +22,10 @@
#ifdef __cplusplus
extern "C" {
#endif
void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int task_id,
const ScaleParameter *scale_param, int max, int min);
void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleParameter *scale_param,
int real_dst_count);
void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
int task_id, const ScaleParameter *scale_param, int max, int min);
const ScaleParameter *scale_param, int real_dst_count);
#ifdef __cplusplus
}
#endif

View File

@ -34,6 +34,8 @@ typedef struct ScaleParameter {
int offset_zp_;
int output_zp_;
int activation_type_;
int output_activation_min_;
int output_activation_max_;
} ScaleParameter;
#endif // MINDSPORE_LITE_NNACL_SCALE_H_

View File

@ -19,6 +19,7 @@
#include <string.h>
#include <vector>
#include "nnacl/int8/scale_int8.h"
#include "nnacl/arithmetic_common.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
@ -35,63 +36,65 @@ constexpr size_t kScaleInputsSize = 2;
constexpr size_t kScaleBiasInputsSize = 3;
} // namespace
ScaleInt8CPUKernel::~ScaleInt8CPUKernel() {
if (scale_param_->const_scale_) {
if (scale_ != nullptr) {
free(scale_);
scale_ = nullptr;
if (tile_para != nullptr) {
free(tile_para);
tile_para = nullptr;
}
if (input1_data_ != nullptr && malloced_scale_) {
free(input1_data_);
}
if (has_bias_ && scale_param_->const_offset_) {
if (offset_ != nullptr) {
free(offset_);
offset_ = nullptr;
}
if (input2_data_ != nullptr && malloced_offset_) {
free(input2_data_);
}
}
int ScaleInt8CPUKernel::InitScaleOffset() {
auto scale_tensor = in_tensors_.at(1);
int8_t *scale_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
CalcMultiplesAndStrides(tile_para);
scale_param_->const_scale_ = false;
auto *scale_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
// scale may be const value ,can be processed in prepare stage
if (scale_ptr != nullptr) {
scale_param_->const_scale_ = true;
if (scale_ != nullptr) {
free(scale_);
scale_ = nullptr;
}
scale_ = reinterpret_cast<int8_t *>(malloc(scale_tensor->ElementsNum() * sizeof(int8_t)));
if (scale_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
input1_data_ = scale_ptr;
// need broadcasting
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
input1_data_ = reinterpret_cast<int8_t *>(malloc(out_tensors_.at(0)->Size()));
if (input1_data_ == nullptr) {
MS_LOG(ERROR) << "malloc input1_data_ failed.";
return RET_ERROR;
}
memcpy(scale_, scale_ptr, scale_tensor->ElementsNum() * sizeof(int8_t));
} else {
scale_param_->const_scale_ = false;
scale_ = nullptr;
malloced_scale_ = true;
TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(1)->data_c()),
reinterpret_cast<uint8_t *>(input1_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
}
}
scale_param_->const_offset_ = false;
if (in_tensors_.size() == 3) {
has_bias_ = true;
auto offset_tensor = in_tensors_.at(2);
int8_t *offset_ptr = reinterpret_cast<int8_t *>(offset_tensor->data_c());
auto *offset_ptr = reinterpret_cast<int8_t *>(offset_tensor->data_c());
// offset may be const value ,can be processed in prepare stage
if (offset_ptr != nullptr) {
scale_param_->const_offset_ = true;
if (offset_ != nullptr) {
free(offset_);
offset_ = nullptr;
}
offset_ = reinterpret_cast<int8_t *>(malloc(offset_tensor->ElementsNum() * sizeof(int8_t)));
if (offset_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
input2_data_ = offset_ptr;
// need broadcasting
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(2)->ElementsNum()) {
input2_data_ = reinterpret_cast<int8_t *>(malloc(out_tensors_.at(0)->Size()));
if (input2_data_ == nullptr) {
MS_LOG(ERROR) << "malloc input2_data_ failed.";
free(input1_data_);
return RET_ERROR;
}
memcpy(offset_, offset_ptr, offset_tensor->ElementsNum() * sizeof(int8_t));
} else {
scale_param_->const_offset_ = false;
offset_ = nullptr;
malloced_offset_ = true;
TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(2)->data_c()),
reinterpret_cast<uint8_t *>(input2_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
}
} else {
has_bias_ = false;
}
}
return RET_OK;
}
@ -102,29 +105,66 @@ int ScaleInt8CPUKernel::InitParameter() {
auto scale_shape = scale_tensor->shape();
if (scale_param_->axis_ < 0) {
scale_param_->axis_ = scale_param_->axis_ + in_shape.size();
scale_param_->axis_ += in_shape.size();
}
if (scale_shape.size() + scale_param_->axis_ > in_shape.size()) {
MS_LOG(ERROR) << "Scale tensor shape is incorrect.";
return RET_ERROR;
}
scale_param_->outer_size_ = 1;
scale_param_->axis_size_ = 1;
scale_param_->inner_size_ = 1;
for (int i = 0; i < scale_param_->axis_; i++) {
scale_param_->outer_size_ *= in_shape[i];
}
for (size_t i = 0; i < scale_shape.size(); i++) {
if (in_shape[i + scale_param_->axis_] != scale_shape[i]) {
MS_LOG(ERROR) << "Scale tensor shape is incorrect.";
return RET_ERROR;
}
scale_param_->axis_size_ *= in_shape[i + scale_param_->axis_];
}
for (size_t i = scale_param_->axis_ + scale_shape.size(); i < in_shape.size(); i++) {
scale_param_->inner_size_ *= in_shape[i];
tile_para = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (tile_para == nullptr) {
MS_LOG(ERROR) << "malloc tile parameter failed.";
return RET_ERROR;
}
scale_param_->op_parameter_.thread_num_ = MSMIN(scale_param_->op_parameter_.thread_num_, scale_param_->outer_size_);
size_t input0_size = in_tensors_.at(0)->shape().size();
size_t input1_size = in_tensors_.at(1)->shape().size();
size_t output_size = out_tensors_.at(0)->shape().size();
auto input1_shape = in_tensors_.at(1)->shape();
tile_para->ndim_ = output_size;
// supplement shape of scale tensor with number 1
size_t len = input0_size - scale_param_->axis_;
second_in_shape_ = input1_shape;
if (len != input1_size) {
second_in_shape_.resize(len);
size_t i = 0;
for (; i < input1_size; ++i) {
second_in_shape_[i] = input1_shape[i];
}
for (; i < len; ++i) {
second_in_shape_[i] = 1;
}
input1_size = len;
}
if (input0_size == input1_size) {
for (size_t i = 0; i < output_size; i++) {
tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
tile_para->in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
tile_para->out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
}
} else {
MS_ASSERT(input0_size > input1_size);
size_t fill_dim_num = input0_size - input1_size;
int j = 0;
for (size_t i = 0; i < output_size; i++) {
tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
if (i < fill_dim_num) {
tile_para->in_shape1_[i] = 1;
} else {
tile_para->in_shape1_[i] = second_in_shape_[j++];
}
tile_para->out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
}
}
return RET_OK;
}
@ -156,6 +196,24 @@ int ScaleInt8CPUKernel::InitQuantArgs() {
scale_param_->offset_mul_arg_.left_shift_ = shift > 0 ? shift : 0;
scale_param_->offset_mul_arg_.right_shift_ = shift < 0 ? -shift : 0;
}
switch (scale_param_->activation_type_) {
case schema::ActivationType_RELU:
scale_param_->output_activation_min_ = 0;
scale_param_->output_activation_max_ = INT8_MAX;
break;
case schema::ActivationType_RELU6:
scale_param_->output_activation_min_ = 0;
scale_param_->output_activation_max_ = 6;
break;
case schema::ActivationType_NO_ACTIVATION:
scale_param_->output_activation_min_ = INT8_MIN;
scale_param_->output_activation_max_ = INT8_MAX;
break;
default:
MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_;
return RET_ERROR;
}
return RET_OK;
}
@ -176,13 +234,13 @@ int ScaleInt8CPUKernel::Init() {
int ScaleInt8CPUKernel::ReSize() {
auto ret = InitParameter();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Scale fp32 InitParameter failed.";
MS_LOG(ERROR) << "Scale int8 InitParameter failed.";
return RET_ERROR;
}
ret = InitScaleOffset();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Scale fp32 InitScaleOffset failed.";
MS_LOG(ERROR) << "Scale int8 InitScaleOffset failed.";
return RET_ERROR;
}
@ -195,38 +253,21 @@ int ScaleInt8CPUKernel::ReSize() {
}
int ScaleInt8CPUKernel::Scale(int task_id) {
if (has_bias_) {
switch (scale_param_->activation_type_) {
case schema::ActivationType_RELU:
DoScaleWithBiasInt8(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_, INT8_MAX, 0);
break;
case schema::ActivationType_RELU6:
DoScaleWithBiasInt8(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_, 6, 0);
break;
case schema::ActivationType_NO_ACTIVATION:
DoScaleWithBiasInt8(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_, INT8_MAX, INT8_MIN);
break;
default:
MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_;
return RET_ERROR;
}
} else {
switch (scale_param_->activation_type_) {
case schema::ActivationType_RELU:
DoScaleInt8(input_ptr_, output_ptr_, scale_, task_id, scale_param_, INT8_MAX, 0);
break;
case schema::ActivationType_RELU6:
DoScaleInt8(input_ptr_, output_ptr_, scale_, task_id, scale_param_, 6, 0);
break;
case schema::ActivationType_NO_ACTIVATION:
DoScaleInt8(input_ptr_, output_ptr_, scale_, task_id, scale_param_, INT8_MAX, INT8_MIN);
break;
default:
MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_;
return RET_ERROR;
}
int real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_);
if (real_dst_count <= 0) {
return lite::RET_OK;
}
int8_t *cur_input0_data = input0_data_ + task_id * count_unit_;
int8_t *cur_input1_data = input1_data_ + task_id * count_unit_;
int8_t *cur_output_data = output_data_ + task_id * count_unit_;
if (has_bias_) {
int8_t *cur_input2_data = input2_data_ + task_id * count_unit_;
DoScaleWithBiasInt8(cur_input0_data, cur_output_data, cur_input1_data, cur_input2_data, scale_param_,
real_dst_count);
} else {
DoScaleInt8(cur_input0_data, cur_output_data, cur_input1_data, scale_param_, real_dst_count);
}
return RET_OK;
}
@ -241,18 +282,59 @@ int ScaleRunInt8(void *cdata, int task_id) {
}
int ScaleInt8CPUKernel::Run() {
auto in_tensor = in_tensors_.front();
input_ptr_ = reinterpret_cast<int8_t *>(in_tensor->data_c());
if (scale_ == nullptr) {
auto scale_tensor = in_tensors_[1];
scale_ = reinterpret_cast<int8_t *>(scale_tensor->data_c());
elements_num_ = out_tensors_.at(0)->ElementsNum();
count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
input0_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data_c());
output_data_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data_c());
// need broadcasting
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
// scale is passed by previous node, need do broadcasting online
if (!scale_param_->const_scale_) {
input1_data_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
if (input1_data_ == nullptr) {
MS_LOG(ERROR) << "malloc input1_data_ failed.";
return RET_ERROR;
}
TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(1)->data_c()),
reinterpret_cast<uint8_t *>(input1_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
}
// If has bias, bias is passed by previous node case, need do broadcasting online
if (has_bias_ && !scale_param_->const_offset_) {
input2_data_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
if (input2_data_ == nullptr) {
MS_LOG(ERROR) << "malloc input2_data_ failed.";
ctx_->allocator->Free(input1_data_);
input1_data_ = nullptr;
return RET_ERROR;
}
TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(2)->data_c()),
reinterpret_cast<uint8_t *>(input2_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
}
auto ret = ParallelLaunch(this->context_->thread_pool_, ScaleRunInt8, this, op_parameter_->thread_num_);
// free memory malloced from memory pool
if (!scale_param_->const_scale_) {
ctx_->allocator->Free(input1_data_);
input1_data_ = nullptr;
}
if (has_bias_ && !scale_param_->const_offset_) {
offset_ = reinterpret_cast<int8_t *>(in_tensors_.at(2)->data_c());
ctx_->allocator->Free(input2_data_);
input2_data_ = nullptr;
}
return ret;
}
auto out_tensor = out_tensors_.front();
output_ptr_ = reinterpret_cast<int8_t *>(out_tensor->data_c());
// input1 has the same shape with input0 situation
if (input1_data_ == nullptr) {
input1_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
}
if (has_bias_ && !scale_param_->const_offset_) {
input2_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(2)->data_c());
}
auto ret = ParallelLaunch(this->context_->thread_pool_, ScaleRunInt8, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Scale error error_code[" << ret << "]";
@ -260,6 +342,7 @@ int ScaleInt8CPUKernel::Run() {
}
return RET_OK;
}
kernel::LiteKernel *CpuScaleInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,

View File

@ -21,6 +21,7 @@
#include "src/lite_kernel.h"
#include "nnacl/scale.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/arithmetic_common.h"
namespace mindspore::kernel {
@ -29,7 +30,7 @@ class ScaleInt8CPUKernel : public LiteKernel {
ScaleInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx_->thread_num_) {
scale_param_ = reinterpret_cast<ScaleParameter *>(op_parameter_);
}
~ScaleInt8CPUKernel() override;
@ -42,12 +43,20 @@ class ScaleInt8CPUKernel : public LiteKernel {
int Scale(int task_id);
private:
int8_t *input_ptr_ = nullptr;
int8_t *scale_ = nullptr;
int8_t *offset_ = nullptr;
int8_t *output_ptr_ = nullptr;
bool has_bias_ = false;
int8_t *input0_data_ = nullptr;
int8_t *input1_data_ = nullptr;
int8_t *input2_data_ = nullptr;
int8_t *output_data_ = nullptr;
const lite::InnerContext *ctx_;
ScaleParameter *scale_param_;
ArithmeticParameter *tile_para = nullptr;
std::vector<int> second_in_shape_;
int thread_count_;
int64_t elements_num_;
int64_t count_unit_;
bool has_bias_ = false;
bool malloced_scale_ = false;
bool malloced_offset_ = false;
int InitQuantArgs();
};