forked from mindspore-Ecosystem/mindspore
!8148 [MS][LITE][CPU]optimize int8 scale op
Merge pull request !8148 from fuzhiye/tmp
This commit is contained in:
commit
c47d0d68fe
|
@ -53,6 +53,8 @@ void ComputeStrides(const int *shape, int *strides, const int ndim);
|
|||
|
||||
void CalcMultiplesAndStrides(ArithmeticParameter *param);
|
||||
|
||||
void TileOneDimensionUint8(uint8_t *inData, uint8_t *outData, int dim, size_t ndim, int *inShape, int *inStrides,
|
||||
int *outStrides, int *multiple);
|
||||
void TileDimensions(float *data0, float *data1, float *tile_data0, float *tile_data1, ArithmeticParameter *param);
|
||||
void TileDimensionsUint8(uint8_t *data0, uint8_t *data1, uint8_t *tile_data0, uint8_t *tile_data1,
|
||||
ArithmeticParameter *param);
|
||||
|
|
|
@ -17,78 +17,148 @@
|
|||
#include "nnacl/int8/scale_int8.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
|
||||
void ScaleInnerInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int outer_start, int outer_end,
|
||||
int axis_size, int inner_size, const ScaleParameter *scale_param, int max, int min) {
|
||||
for (int out = outer_start; out < outer_end; out++) {
|
||||
int out_offset = out * axis_size * inner_size;
|
||||
for (int i = 0; i < axis_size; i++) {
|
||||
int axis_offset = out_offset + i * inner_size;
|
||||
int in_index = 0;
|
||||
|
||||
for (; in_index < inner_size; in_index++) {
|
||||
int in_offset = axis_offset + in_index;
|
||||
int tmp_input_scale = (in_data[in_offset] - scale_param->input_zp_) * (scale[i] - scale_param->scale_zp_);
|
||||
int input_mul_scale =
|
||||
RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
|
||||
tmp_input_scale * (1 << (unsigned int)scale_param->scale_mul_arg_.left_shift_),
|
||||
scale_param->scale_mul_arg_.multiplier_),
|
||||
scale_param->scale_mul_arg_.right_shift_);
|
||||
int tmp = input_mul_scale + scale_param->output_zp_;
|
||||
tmp = tmp > max ? max : tmp;
|
||||
tmp = tmp < min ? min : tmp;
|
||||
out_data[in_offset] = tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef ENABLE_NEON
|
||||
int16x4_t ClacSumHalfWordMul2(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec,
|
||||
int32x4_t output_multiplier_vec, const ScaleParameter *scale_param) {
|
||||
int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1);
|
||||
int32x4_t raw_sum = RoundingDivideByPOTInt32x4(
|
||||
SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec),
|
||||
scale_param->scale_mul_arg_.right_shift_);
|
||||
raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(scale_param->output_zp_));
|
||||
raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_min_));
|
||||
raw_sum = vminq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_max_));
|
||||
return vqmovn_s32(raw_sum);
|
||||
}
|
||||
|
||||
void ScaleInnerWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
|
||||
int outer_start, int outer_end, int axis_size, int inner_size,
|
||||
const ScaleParameter *scale_param, int max, int min) {
|
||||
for (int out = outer_start; out < outer_end; out++) {
|
||||
int out_offset = out * axis_size * inner_size;
|
||||
for (int i = 0; i < axis_size; i++) {
|
||||
int axis_offset = out_offset + i * inner_size;
|
||||
int in_index = 0;
|
||||
int16x4_t ClacSumHalfWordMul3(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t scaled_input2,
|
||||
const ScaleParameter *scale_param) {
|
||||
int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_);
|
||||
int32x4_t output_multiplier_vec2 = vdupq_n_s32(scale_param->offset_mul_arg_.multiplier_);
|
||||
int32x4_t left_shift_out_vec = vdupq_n_s32(1 << scale_param->scale_mul_arg_.left_shift_);
|
||||
int32x4_t left_shift_out_vec2 = vdupq_n_s32(1 << scale_param->offset_mul_arg_.left_shift_);
|
||||
int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1);
|
||||
int32x4_t raw_sum = RoundingDivideByPOTInt32x4(
|
||||
SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec),
|
||||
scale_param->scale_mul_arg_.right_shift_);
|
||||
int32x4_t raw_sum2 = RoundingDivideByPOTInt32x4(
|
||||
SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(scaled_input2, left_shift_out_vec2), output_multiplier_vec2),
|
||||
scale_param->offset_mul_arg_.right_shift_);
|
||||
raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(scale_param->output_zp_));
|
||||
raw_sum = vaddq_s32(raw_sum, raw_sum2);
|
||||
raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_min_));
|
||||
raw_sum = vminq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_max_));
|
||||
return vqmovn_s32(raw_sum);
|
||||
}
|
||||
#endif
|
||||
|
||||
for (; in_index < inner_size; in_index++) {
|
||||
int in_offset = axis_offset + in_index;
|
||||
int tmp_input_scale = (in_data[in_offset] - scale_param->input_zp_) * (scale[i] - scale_param->scale_zp_);
|
||||
int input_mul_scale =
|
||||
RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
|
||||
tmp_input_scale * (1 << (unsigned int)scale_param->scale_mul_arg_.left_shift_),
|
||||
scale_param->scale_mul_arg_.multiplier_),
|
||||
scale_param->scale_mul_arg_.right_shift_);
|
||||
int tmp_bias = offset[i] - scale_param->offset_zp_;
|
||||
int bias = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(tmp_bias * (1 << (unsigned int)scale_param->offset_mul_arg_.left_shift_),
|
||||
scale_param->offset_mul_arg_.multiplier_),
|
||||
scale_param->offset_mul_arg_.right_shift_);
|
||||
int tmp = input_mul_scale + bias + scale_param->output_zp_;
|
||||
tmp = tmp > max ? max : tmp;
|
||||
tmp = tmp < min ? min : tmp;
|
||||
out_data[in_offset] = tmp;
|
||||
}
|
||||
void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleParameter *scale_param,
|
||||
int real_dst_count) {
|
||||
int index = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_);
|
||||
int32x4_t left_shift_out_vec = vdupq_n_s32(1 << scale_param->scale_mul_arg_.left_shift_);
|
||||
|
||||
for (; index <= real_dst_count - 8; index += 8) {
|
||||
int8x8_t input_s8 = vld1_s8(in_data + index);
|
||||
int16x8_t input_s16 = vmovl_s8(input_s8);
|
||||
int16x8_t input0_val = vaddq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_));
|
||||
|
||||
int8x8_t input1_s8 = vld1_s8(scale + index);
|
||||
int16x8_t input1_s16 = vmovl_s8(input1_s8);
|
||||
int16x8_t input1_val = vaddq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_));
|
||||
|
||||
int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val));
|
||||
int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val));
|
||||
int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val));
|
||||
int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val));
|
||||
|
||||
int16x4_t sum_low =
|
||||
ClacSumHalfWordMul2(input0_low, input1_low, left_shift_out_vec, output_multiplier_vec, scale_param);
|
||||
int16x4_t sum_high =
|
||||
ClacSumHalfWordMul2(input0_high, input1_high, left_shift_out_vec, output_multiplier_vec, scale_param);
|
||||
|
||||
int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
|
||||
int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
|
||||
vst1_s8(out_data, res_u8_n0);
|
||||
out_data += 8;
|
||||
}
|
||||
#endif
|
||||
for (; index < real_dst_count; ++index) {
|
||||
const int32_t input0_val = scale_param->input_zp_ + in_data[index];
|
||||
const int32_t input1_val = scale_param->scale_zp_ + scale[index];
|
||||
int32_t mul_result = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << scale_param->scale_mul_arg_.left_shift_),
|
||||
scale_param->scale_mul_arg_.multiplier_),
|
||||
scale_param->scale_mul_arg_.right_shift_);
|
||||
|
||||
mul_result += scale_param->output_zp_;
|
||||
|
||||
if (mul_result > scale_param->output_activation_max_) {
|
||||
out_data[index] = scale_param->output_activation_max_;
|
||||
} else if (mul_result < scale_param->output_activation_min_) {
|
||||
out_data[index] = scale_param->output_activation_min_;
|
||||
} else {
|
||||
out_data[index] = (int8_t)mul_result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int task_id,
|
||||
const ScaleParameter *scale_param, int max, int min) {
|
||||
int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
|
||||
int outer_start = task_id * outer_step;
|
||||
int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
|
||||
|
||||
ScaleInnerInt8(in_data, out_data, scale, outer_start, outer_end, scale_param->axis_size_, scale_param->inner_size_,
|
||||
scale_param, max, min);
|
||||
return;
|
||||
}
|
||||
|
||||
void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
|
||||
int task_id, const ScaleParameter *scale_param, int max, int min) {
|
||||
int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
|
||||
int outer_start = task_id * outer_step;
|
||||
int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
|
||||
const ScaleParameter *scale_param, int real_dst_count) {
|
||||
int index = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
for (; index <= real_dst_count - 8; index += 8) {
|
||||
int8x8_t input_s8 = vld1_s8(in_data + index);
|
||||
int16x8_t input_s16 = vmovl_s8(input_s8);
|
||||
int16x8_t input0_val = vaddq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_));
|
||||
|
||||
ScaleInnerWithBiasInt8(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_,
|
||||
scale_param->inner_size_, scale_param, max, min);
|
||||
int8x8_t input1_s8 = vld1_s8(scale + index);
|
||||
int16x8_t input1_s16 = vmovl_s8(input1_s8);
|
||||
int16x8_t input1_val = vaddq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_));
|
||||
|
||||
int8x8_t input2_s8 = vld1_s8(offset + index);
|
||||
int16x8_t input2_s16 = vmovl_s8(input2_s8);
|
||||
int16x8_t input2_val = vaddq_s16(input2_s16, vdupq_n_s16(scale_param->offset_zp_));
|
||||
|
||||
int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val));
|
||||
int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val));
|
||||
int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val));
|
||||
int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val));
|
||||
int32x4_t input2_low = vmovl_s16(vget_low_s16(input2_val));
|
||||
int32x4_t input2_high = vmovl_s16(vget_high_s16(input2_val));
|
||||
|
||||
int16x4_t sum_low = ClacSumHalfWordMul3(input0_low, input1_low, input2_low, scale_param);
|
||||
int16x4_t sum_high = ClacSumHalfWordMul3(input0_high, input1_high, input2_high, scale_param);
|
||||
|
||||
int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
|
||||
int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
|
||||
vst1_s8(out_data, res_u8_n0);
|
||||
out_data += 8;
|
||||
}
|
||||
#endif
|
||||
for (; index < real_dst_count; ++index) {
|
||||
const int32_t input0_val = in_data[index] - scale_param->input_zp_;
|
||||
const int32_t input1_val = scale[index] - scale_param->scale_zp_;
|
||||
int32_t mul_result = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << scale_param->scale_mul_arg_.left_shift_),
|
||||
scale_param->scale_mul_arg_.multiplier_),
|
||||
scale_param->scale_mul_arg_.right_shift_);
|
||||
int tmp_bias = offset[index] - scale_param->offset_zp_;
|
||||
int bias = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(tmp_bias * (1 << (unsigned int)scale_param->offset_mul_arg_.left_shift_),
|
||||
scale_param->offset_mul_arg_.multiplier_),
|
||||
scale_param->offset_mul_arg_.right_shift_);
|
||||
|
||||
mul_result += bias + scale_param->output_zp_;
|
||||
|
||||
if (mul_result > scale_param->output_activation_max_) {
|
||||
out_data[index] = scale_param->output_activation_max_;
|
||||
} else if (mul_result < scale_param->output_activation_min_) {
|
||||
out_data[index] = scale_param->output_activation_min_;
|
||||
} else {
|
||||
out_data[index] = (int8_t)mul_result;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -22,10 +22,10 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int task_id,
|
||||
const ScaleParameter *scale_param, int max, int min);
|
||||
void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleParameter *scale_param,
|
||||
int real_dst_count);
|
||||
void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
|
||||
int task_id, const ScaleParameter *scale_param, int max, int min);
|
||||
const ScaleParameter *scale_param, int real_dst_count);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -34,6 +34,8 @@ typedef struct ScaleParameter {
|
|||
int offset_zp_;
|
||||
int output_zp_;
|
||||
int activation_type_;
|
||||
int output_activation_min_;
|
||||
int output_activation_max_;
|
||||
} ScaleParameter;
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_SCALE_H_
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <string.h>
|
||||
#include <vector>
|
||||
#include "nnacl/int8/scale_int8.h"
|
||||
#include "nnacl/arithmetic_common.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
|
@ -35,63 +36,65 @@ constexpr size_t kScaleInputsSize = 2;
|
|||
constexpr size_t kScaleBiasInputsSize = 3;
|
||||
} // namespace
|
||||
ScaleInt8CPUKernel::~ScaleInt8CPUKernel() {
|
||||
if (scale_param_->const_scale_) {
|
||||
if (scale_ != nullptr) {
|
||||
free(scale_);
|
||||
scale_ = nullptr;
|
||||
}
|
||||
if (tile_para != nullptr) {
|
||||
free(tile_para);
|
||||
tile_para = nullptr;
|
||||
}
|
||||
if (has_bias_ && scale_param_->const_offset_) {
|
||||
if (offset_ != nullptr) {
|
||||
free(offset_);
|
||||
offset_ = nullptr;
|
||||
}
|
||||
if (input1_data_ != nullptr && malloced_scale_) {
|
||||
free(input1_data_);
|
||||
}
|
||||
if (input2_data_ != nullptr && malloced_offset_) {
|
||||
free(input2_data_);
|
||||
}
|
||||
}
|
||||
|
||||
int ScaleInt8CPUKernel::InitScaleOffset() {
|
||||
auto scale_tensor = in_tensors_.at(1);
|
||||
int8_t *scale_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
|
||||
CalcMultiplesAndStrides(tile_para);
|
||||
scale_param_->const_scale_ = false;
|
||||
auto *scale_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
|
||||
// scale may be const value ,can be processed in prepare stage
|
||||
if (scale_ptr != nullptr) {
|
||||
scale_param_->const_scale_ = true;
|
||||
if (scale_ != nullptr) {
|
||||
free(scale_);
|
||||
scale_ = nullptr;
|
||||
input1_data_ = scale_ptr;
|
||||
// need broadcasting
|
||||
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
|
||||
input1_data_ = reinterpret_cast<int8_t *>(malloc(out_tensors_.at(0)->Size()));
|
||||
if (input1_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc input1_data_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
malloced_scale_ = true;
|
||||
TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(1)->data_c()),
|
||||
reinterpret_cast<uint8_t *>(input1_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
|
||||
tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
|
||||
}
|
||||
scale_ = reinterpret_cast<int8_t *>(malloc(scale_tensor->ElementsNum() * sizeof(int8_t)));
|
||||
if (scale_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memcpy(scale_, scale_ptr, scale_tensor->ElementsNum() * sizeof(int8_t));
|
||||
} else {
|
||||
scale_param_->const_scale_ = false;
|
||||
scale_ = nullptr;
|
||||
}
|
||||
|
||||
scale_param_->const_offset_ = false;
|
||||
if (in_tensors_.size() == 3) {
|
||||
has_bias_ = true;
|
||||
auto offset_tensor = in_tensors_.at(2);
|
||||
int8_t *offset_ptr = reinterpret_cast<int8_t *>(offset_tensor->data_c());
|
||||
auto *offset_ptr = reinterpret_cast<int8_t *>(offset_tensor->data_c());
|
||||
// offset may be const value ,can be processed in prepare stage
|
||||
if (offset_ptr != nullptr) {
|
||||
scale_param_->const_offset_ = true;
|
||||
if (offset_ != nullptr) {
|
||||
free(offset_);
|
||||
offset_ = nullptr;
|
||||
input2_data_ = offset_ptr;
|
||||
// need broadcasting
|
||||
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(2)->ElementsNum()) {
|
||||
input2_data_ = reinterpret_cast<int8_t *>(malloc(out_tensors_.at(0)->Size()));
|
||||
if (input2_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc input2_data_ failed.";
|
||||
free(input1_data_);
|
||||
return RET_ERROR;
|
||||
}
|
||||
malloced_offset_ = true;
|
||||
TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(2)->data_c()),
|
||||
reinterpret_cast<uint8_t *>(input2_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
|
||||
tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
|
||||
}
|
||||
offset_ = reinterpret_cast<int8_t *>(malloc(offset_tensor->ElementsNum() * sizeof(int8_t)));
|
||||
if (offset_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memcpy(offset_, offset_ptr, offset_tensor->ElementsNum() * sizeof(int8_t));
|
||||
} else {
|
||||
scale_param_->const_offset_ = false;
|
||||
offset_ = nullptr;
|
||||
}
|
||||
} else {
|
||||
has_bias_ = false;
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -102,29 +105,66 @@ int ScaleInt8CPUKernel::InitParameter() {
|
|||
auto scale_shape = scale_tensor->shape();
|
||||
|
||||
if (scale_param_->axis_ < 0) {
|
||||
scale_param_->axis_ = scale_param_->axis_ + in_shape.size();
|
||||
scale_param_->axis_ += in_shape.size();
|
||||
}
|
||||
if (scale_shape.size() + scale_param_->axis_ > in_shape.size()) {
|
||||
MS_LOG(ERROR) << "Scale tensor shape is incorrect.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
scale_param_->outer_size_ = 1;
|
||||
scale_param_->axis_size_ = 1;
|
||||
scale_param_->inner_size_ = 1;
|
||||
for (int i = 0; i < scale_param_->axis_; i++) {
|
||||
scale_param_->outer_size_ *= in_shape[i];
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < scale_shape.size(); i++) {
|
||||
if (in_shape[i + scale_param_->axis_] != scale_shape[i]) {
|
||||
MS_LOG(ERROR) << "Scale tensor shape is incorrect.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
scale_param_->axis_size_ *= in_shape[i + scale_param_->axis_];
|
||||
}
|
||||
for (size_t i = scale_param_->axis_ + scale_shape.size(); i < in_shape.size(); i++) {
|
||||
scale_param_->inner_size_ *= in_shape[i];
|
||||
|
||||
tile_para = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
|
||||
if (tile_para == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tile parameter failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
scale_param_->op_parameter_.thread_num_ = MSMIN(scale_param_->op_parameter_.thread_num_, scale_param_->outer_size_);
|
||||
size_t input0_size = in_tensors_.at(0)->shape().size();
|
||||
size_t input1_size = in_tensors_.at(1)->shape().size();
|
||||
size_t output_size = out_tensors_.at(0)->shape().size();
|
||||
auto input1_shape = in_tensors_.at(1)->shape();
|
||||
tile_para->ndim_ = output_size;
|
||||
// supplement shape of scale tensor with number 1
|
||||
size_t len = input0_size - scale_param_->axis_;
|
||||
second_in_shape_ = input1_shape;
|
||||
if (len != input1_size) {
|
||||
second_in_shape_.resize(len);
|
||||
size_t i = 0;
|
||||
for (; i < input1_size; ++i) {
|
||||
second_in_shape_[i] = input1_shape[i];
|
||||
}
|
||||
for (; i < len; ++i) {
|
||||
second_in_shape_[i] = 1;
|
||||
}
|
||||
input1_size = len;
|
||||
}
|
||||
|
||||
if (input0_size == input1_size) {
|
||||
for (size_t i = 0; i < output_size; i++) {
|
||||
tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
|
||||
tile_para->in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
|
||||
tile_para->out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
|
||||
}
|
||||
} else {
|
||||
MS_ASSERT(input0_size > input1_size);
|
||||
size_t fill_dim_num = input0_size - input1_size;
|
||||
int j = 0;
|
||||
for (size_t i = 0; i < output_size; i++) {
|
||||
tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
|
||||
if (i < fill_dim_num) {
|
||||
tile_para->in_shape1_[i] = 1;
|
||||
} else {
|
||||
tile_para->in_shape1_[i] = second_in_shape_[j++];
|
||||
}
|
||||
tile_para->out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
|
||||
}
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -156,6 +196,24 @@ int ScaleInt8CPUKernel::InitQuantArgs() {
|
|||
scale_param_->offset_mul_arg_.left_shift_ = shift > 0 ? shift : 0;
|
||||
scale_param_->offset_mul_arg_.right_shift_ = shift < 0 ? -shift : 0;
|
||||
}
|
||||
|
||||
switch (scale_param_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
scale_param_->output_activation_min_ = 0;
|
||||
scale_param_->output_activation_max_ = INT8_MAX;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
scale_param_->output_activation_min_ = 0;
|
||||
scale_param_->output_activation_max_ = 6;
|
||||
break;
|
||||
case schema::ActivationType_NO_ACTIVATION:
|
||||
scale_param_->output_activation_min_ = INT8_MIN;
|
||||
scale_param_->output_activation_max_ = INT8_MAX;
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -176,13 +234,13 @@ int ScaleInt8CPUKernel::Init() {
|
|||
int ScaleInt8CPUKernel::ReSize() {
|
||||
auto ret = InitParameter();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Scale fp32 InitParameter failed.";
|
||||
MS_LOG(ERROR) << "Scale int8 InitParameter failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ret = InitScaleOffset();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Scale fp32 InitScaleOffset failed.";
|
||||
MS_LOG(ERROR) << "Scale int8 InitScaleOffset failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
|
@ -195,38 +253,21 @@ int ScaleInt8CPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
int ScaleInt8CPUKernel::Scale(int task_id) {
|
||||
if (has_bias_) {
|
||||
switch (scale_param_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
DoScaleWithBiasInt8(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_, INT8_MAX, 0);
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
DoScaleWithBiasInt8(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_, 6, 0);
|
||||
break;
|
||||
case schema::ActivationType_NO_ACTIVATION:
|
||||
DoScaleWithBiasInt8(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_, INT8_MAX, INT8_MIN);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
switch (scale_param_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
DoScaleInt8(input_ptr_, output_ptr_, scale_, task_id, scale_param_, INT8_MAX, 0);
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
DoScaleInt8(input_ptr_, output_ptr_, scale_, task_id, scale_param_, 6, 0);
|
||||
break;
|
||||
case schema::ActivationType_NO_ACTIVATION:
|
||||
DoScaleInt8(input_ptr_, output_ptr_, scale_, task_id, scale_param_, INT8_MAX, INT8_MIN);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
int real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_);
|
||||
if (real_dst_count <= 0) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
int8_t *cur_input0_data = input0_data_ + task_id * count_unit_;
|
||||
int8_t *cur_input1_data = input1_data_ + task_id * count_unit_;
|
||||
int8_t *cur_output_data = output_data_ + task_id * count_unit_;
|
||||
|
||||
if (has_bias_) {
|
||||
int8_t *cur_input2_data = input2_data_ + task_id * count_unit_;
|
||||
DoScaleWithBiasInt8(cur_input0_data, cur_output_data, cur_input1_data, cur_input2_data, scale_param_,
|
||||
real_dst_count);
|
||||
} else {
|
||||
DoScaleInt8(cur_input0_data, cur_output_data, cur_input1_data, scale_param_, real_dst_count);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -241,18 +282,59 @@ int ScaleRunInt8(void *cdata, int task_id) {
|
|||
}
|
||||
|
||||
int ScaleInt8CPUKernel::Run() {
|
||||
auto in_tensor = in_tensors_.front();
|
||||
input_ptr_ = reinterpret_cast<int8_t *>(in_tensor->data_c());
|
||||
if (scale_ == nullptr) {
|
||||
auto scale_tensor = in_tensors_[1];
|
||||
scale_ = reinterpret_cast<int8_t *>(scale_tensor->data_c());
|
||||
elements_num_ = out_tensors_.at(0)->ElementsNum();
|
||||
count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
|
||||
input0_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data_c());
|
||||
output_data_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data_c());
|
||||
|
||||
// need broadcasting
|
||||
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
|
||||
// scale is passed by previous node, need do broadcasting online
|
||||
if (!scale_param_->const_scale_) {
|
||||
input1_data_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||
if (input1_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc input1_data_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(1)->data_c()),
|
||||
reinterpret_cast<uint8_t *>(input1_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
|
||||
tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
|
||||
}
|
||||
|
||||
// If has bias, bias is passed by previous node case, need do broadcasting online
|
||||
if (has_bias_ && !scale_param_->const_offset_) {
|
||||
input2_data_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||
if (input2_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc input2_data_ failed.";
|
||||
ctx_->allocator->Free(input1_data_);
|
||||
input1_data_ = nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(2)->data_c()),
|
||||
reinterpret_cast<uint8_t *>(input2_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
|
||||
tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
|
||||
}
|
||||
|
||||
auto ret = ParallelLaunch(this->context_->thread_pool_, ScaleRunInt8, this, op_parameter_->thread_num_);
|
||||
// free memory malloced from memory pool
|
||||
if (!scale_param_->const_scale_) {
|
||||
ctx_->allocator->Free(input1_data_);
|
||||
input1_data_ = nullptr;
|
||||
}
|
||||
if (has_bias_ && !scale_param_->const_offset_) {
|
||||
ctx_->allocator->Free(input2_data_);
|
||||
input2_data_ = nullptr;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// input1 has the same shape with input0 situation
|
||||
if (input1_data_ == nullptr) {
|
||||
input1_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
|
||||
}
|
||||
if (has_bias_ && !scale_param_->const_offset_) {
|
||||
offset_ = reinterpret_cast<int8_t *>(in_tensors_.at(2)->data_c());
|
||||
input2_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(2)->data_c());
|
||||
}
|
||||
auto out_tensor = out_tensors_.front();
|
||||
output_ptr_ = reinterpret_cast<int8_t *>(out_tensor->data_c());
|
||||
|
||||
auto ret = ParallelLaunch(this->context_->thread_pool_, ScaleRunInt8, this, op_parameter_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Scale error error_code[" << ret << "]";
|
||||
|
@ -260,6 +342,7 @@ int ScaleInt8CPUKernel::Run() {
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuScaleInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
|
||||
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "src/lite_kernel.h"
|
||||
#include "nnacl/scale.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/arithmetic_common.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
|
@ -29,7 +30,7 @@ class ScaleInt8CPUKernel : public LiteKernel {
|
|||
ScaleInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx_->thread_num_) {
|
||||
scale_param_ = reinterpret_cast<ScaleParameter *>(op_parameter_);
|
||||
}
|
||||
~ScaleInt8CPUKernel() override;
|
||||
|
@ -42,12 +43,20 @@ class ScaleInt8CPUKernel : public LiteKernel {
|
|||
int Scale(int task_id);
|
||||
|
||||
private:
|
||||
int8_t *input_ptr_ = nullptr;
|
||||
int8_t *scale_ = nullptr;
|
||||
int8_t *offset_ = nullptr;
|
||||
int8_t *output_ptr_ = nullptr;
|
||||
bool has_bias_ = false;
|
||||
int8_t *input0_data_ = nullptr;
|
||||
int8_t *input1_data_ = nullptr;
|
||||
int8_t *input2_data_ = nullptr;
|
||||
int8_t *output_data_ = nullptr;
|
||||
const lite::InnerContext *ctx_;
|
||||
ScaleParameter *scale_param_;
|
||||
ArithmeticParameter *tile_para = nullptr;
|
||||
std::vector<int> second_in_shape_;
|
||||
int thread_count_;
|
||||
int64_t elements_num_;
|
||||
int64_t count_unit_;
|
||||
bool has_bias_ = false;
|
||||
bool malloced_scale_ = false;
|
||||
bool malloced_offset_ = false;
|
||||
|
||||
int InitQuantArgs();
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue